In [3]:
import os
import json
import glob
import pandas as pd
import torch
from datasets import Dataset, DatasetDict
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    DataCollatorForLanguageModeling
)
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training, get_peft_model
from trl import SFTTrainer
import evaluate


In [21]:

# Set your Hugging Face token
os.environ["HUGGING_FACE_HUB_TOKEN"] = "Yhf_mMumFeTmfAFRYpPraIttKtHdXWTvPqkKVV"  # Replace with your actual token


In [22]:

# Step 1: Load and preprocess data from JSON files
def load_json_files(directory_path):
    # Get all JSON files in the directory
    json_files = glob.glob(os.path.join(directory_path, "**/*.json"), recursive=True)
    
    all_data = []
    print(f"Found {len(json_files)} JSON files")
    
    for file_path in json_files:
        try:
            with open(file_path, 'r', encoding='utf-8') as file:
                data = json.load(file)
                
                # Process each Q&A pair in the file
                for item in data:
                    question = item.get("question", "").strip()
                    answer = item.get("answer", "").strip()
                    
                    # Skip items with empty answers
                    if not question or not answer:
                        continue
                    
                    # Clean the question: remove "Q." or "Q " prefix
                    if question.startswith("Q."):
                        question = question[2:].strip()
                    elif question.startswith("Q "):
                        question = question[2:].strip()
                    
                    # Clean the answer: remove "A." or "A " prefix
                    if answer.startswith("A."):
                        answer = answer[2:].strip()
                    elif answer.startswith("A "):
                        answer = answer[2:].strip()
                    
                    all_data.append({
                        "Question": question,
                        "Answer": answer
                    })
        except Exception as e:
            print(f"Error processing file {file_path}: {e}")
    
    print(f"Loaded {len(all_data)} question-answer pairs")
    return all_data


In [23]:

# Set path to your data directory
base_directory = "/home/hailemicaelyimer/Desktop/immigration-assistant/frequently-asked-questions"

# Load and combine all data
all_qa_data = load_json_files(base_directory)

# Convert to DataFrame for easier manipulation
df = pd.DataFrame(all_qa_data)
print(f"Dataset shape: {df.shape}")
print("Sample data:")
print(df.head(2))

# Split into train, validation, and test sets (80%, 10%, 10%)
train_size = int(0.8 * len(df))
val_size = int(0.1 * len(df))

train_df = df[:train_size]
val_df = df[train_size:train_size+val_size]
test_df = df[train_size+val_size:]

print(f"Train size: {len(train_df)}, Validation size: {len(val_df)}, Test size: {len(test_df)}")

# Convert to Hugging Face datasets
train_dataset = Dataset.from_pandas(train_df)
val_dataset = Dataset.from_pandas(val_df)
test_dataset = Dataset.from_pandas(test_df)

# Combine into a dataset dictionary
dataset_dict = DatasetDict({
    'train': train_dataset,
    'validation': val_dataset,
    'test': test_dataset
})

# Save the dataset to disk for future use
dataset_path = "./immigration_qa_dataset"
os.makedirs(dataset_path, exist_ok=True)
dataset_dict.save_to_disk(dataset_path)
print(f"Dataset saved to {dataset_path}")



Found 27 JSON files
Loaded 303 question-answer pairs
Dataset shape: (303, 2)
Sample data:
                                            Question  \
0  After one year, how do I demonstrate that the ...   
1  Where can I find information about vaccination...   

                                              Answer  
0  International Entrepreneur RuleUnder the Inter...  
1  CDC publishes information about vaccinations i...  
Train size: 242, Validation size: 30, Test size: 31


Saving the dataset (1/1 shards): 100%|██████████| 242/242 [00:00<00:00, 57855.77 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 30/30 [00:00<00:00, 4366.03 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 31/31 [00:00<00:00, 3587.15 examples/s]

Dataset saved to ./immigration_qa_dataset





In [33]:
model_id = "mistralai/Mistral-7B-Instruct-v0.2"   # Smaller fully open model


In [24]:
# Define quantization config for 4-bit precision
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=False,
)

# Load tokenizer
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.environ["HUGGING_FACE_HUB_TOKEN"])
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# Load model with quantization config
print("Loading model...")
device_map = {"": 0}  # Use GPU 0
model = AutoModelForCausalLM.from_pretrained(
    model_id, 
    quantization_config=bnb_config, 
    use_cache=False,
    device_map=device_map,
    token=os.environ["HUGGING_FACE_HUB_TOKEN"]
)


Loading tokenizer...




Loading model...


  return torch.load(checkpoint_file, map_location="cpu")


In [26]:

# Step 3: Define preprocessing function for Mistral format
def preprocess_function(examples):
    # Format for Mistral Instruct
    return {
        "input_ids": tokenizer(
            "<s>[INST] You are an immigration assistant providing accurate information based on USCIS guidelines. " + 
            "Answer the following question thoroughly and correctly:\n\n" + 
            examples["Question"] + " [/INST]", 
            truncation=True, 
            max_length=512, 
            padding="max_length"
        )["input_ids"],
        "labels": tokenizer(
            examples["Answer"], 
            truncation=True, 
            max_length=512, 
            padding="max_length"
        )["input_ids"],
        "inputs_text": (
            "<s>[INST] You are an immigration assistant providing accurate information based on USCIS guidelines. " + 
            "Answer the following question thoroughly and correctly:\n\n" + 
            examples["Question"] + " [/INST] " + 
            examples["Answer"] + "</s>"
        ),
    }
# Apply preprocessing to datasets
print("Preprocessing datasets...")
processed_train_dataset = dataset_dict['train'].map(preprocess_function)
processed_val_dataset = dataset_dict['validation'].map(preprocess_function)
processed_test_dataset = dataset_dict['test'].map(preprocess_function)

print(f"Processed train dataset size: {len(processed_train_dataset)}")
print(f"Processed validation dataset size: {len(processed_val_dataset)}")
print(f"Processed test dataset size: {len(processed_test_dataset)}")


Preprocessing datasets...


Map: 100%|██████████| 242/242 [00:00<00:00, 463.91 examples/s]
Map: 100%|██████████| 30/30 [00:00<00:00, 253.85 examples/s]
Map: 100%|██████████| 31/31 [00:00<00:00, 270.59 examples/s]

Processed train dataset size: 242
Processed validation dataset size: 30
Processed test dataset size: 31





In [27]:
# Step 4: Configure LoRA for efficient fine-tuning
peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.1,
    r=64,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"]  # Targeting attention modules for Mistral
)

# Prepare model for kbit training
print("Preparing model for training...")
model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

Preparing model for training...
trainable params: 18,874,368 || all params: 730,652,672 || trainable%: 2.5832202800731014


In [28]:
# Step 5: Define training arguments
output_dir = './immigration_assistant_model'
os.makedirs(output_dir, exist_ok=True)

In [30]:
training_arguments = TrainingArguments(
    output_dir=output_dir,
    num_train_epochs=5,  # Increased from 3 to 5 for better learning
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=4,  # Increased for more stable training
    gradient_checkpointing=True,
    optim="paged_adamw_32bit",
    logging_steps=10,
    save_strategy="epoch",
    evaluation_strategy="epoch",
    learning_rate=5e-5,  # Adjusted for better learning
    weight_decay=0.01,
    max_grad_norm=0.3,
    warmup_ratio=0.05,  # Increased for better stability
    group_by_length=True,
    lr_scheduler_type="cosine",  # Changed to cosine for better convergence
    fp16=False,  # Set to True if your GPU supports it
    bf16=False,  # Set to True for newer NVIDIA GPUs
    report_to="none"
)

# Step 6: Create and train the model
trainer = SFTTrainer(
    model=model,
    train_dataset=processed_train_dataset,
    eval_dataset=processed_val_dataset,
    dataset_text_field="inputs_text",
    max_seq_length=512,
    tokenizer=tokenizer,
    args=training_arguments,
    packing=False,
)

# Start training
print("Starting training...")
trainer.train()


Map: 100%|██████████| 242/242 [00:00<00:00, 2611.14 examples/s]
Map: 100%|██████████| 30/30 [00:00<00:00, 357.83 examples/s]


Starting training...


  0%|          | 0/20 [00:00<?, ?it/s]You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
  return fn(*args, **kwargs)
                                              
 20%|██        | 4/20 [00:42<02:34,  9.63s/it]

{'eval_loss': 2.828232765197754, 'eval_runtime': 2.6685, 'eval_samples_per_second': 11.242, 'eval_steps_per_second': 0.749, 'epoch': 1.0}


  return fn(*args, **kwargs)
                                              
 40%|████      | 8/20 [01:24<01:58,  9.89s/it]

{'eval_loss': 2.767277479171753, 'eval_runtime': 2.6863, 'eval_samples_per_second': 11.168, 'eval_steps_per_second': 0.745, 'epoch': 2.0}


  return fn(*args, **kwargs)
 50%|█████     | 10/20 [01:45<01:47, 10.78s/it]

{'loss': 2.939, 'learning_rate': 2.7064483636808313e-05, 'epoch': 2.5}


                                               
 60%|██████    | 12/20 [02:06<01:20, 10.04s/it]

{'eval_loss': 2.7279841899871826, 'eval_runtime': 2.6809, 'eval_samples_per_second': 11.19, 'eval_steps_per_second': 0.746, 'epoch': 3.0}


  return fn(*args, **kwargs)
                                               
 80%|████████  | 16/20 [02:48<00:39,  9.98s/it]

{'eval_loss': 2.7100119590759277, 'eval_runtime': 2.6834, 'eval_samples_per_second': 11.18, 'eval_steps_per_second': 0.745, 'epoch': 4.0}


  return fn(*args, **kwargs)
100%|██████████| 20/20 [03:28<00:00,  9.97s/it]

{'loss': 2.8445, 'learning_rate': 0.0, 'epoch': 5.0}


                                               
100%|██████████| 20/20 [03:30<00:00,  9.97s/it]

{'eval_loss': 2.70660400390625, 'eval_runtime': 2.9033, 'eval_samples_per_second': 10.333, 'eval_steps_per_second': 0.689, 'epoch': 5.0}


100%|██████████| 20/20 [03:31<00:00, 10.56s/it]

{'train_runtime': 211.2791, 'train_samples_per_second': 5.727, 'train_steps_per_second': 0.095, 'train_loss': 2.8917407989501953, 'epoch': 5.0}





TrainOutput(global_step=20, training_loss=2.8917407989501953, metrics={'train_runtime': 211.2791, 'train_samples_per_second': 5.727, 'train_steps_per_second': 0.095, 'train_loss': 2.8917407989501953, 'epoch': 5.0})

In [31]:

# Step 7: Save the trained model
model_path = "./immigration_assistant_final"
os.makedirs(model_path, exist_ok=True)
trainer.model.save_pretrained(model_path)
tokenizer.save_pretrained(model_path)
print(f"Model saved to {model_path}")

# Step 8: Test the model on a few examples
print("Testing model on examples...")

# Load rouge for evaluation
rouge = evaluate.load('rouge')

# Reload base model for comparison
print("Loading base model for comparison...")
base_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map=device_map,
    token=os.environ["HUGGING_FACE_HUB_TOKEN"]
)

# Load fine-tuned model (PEFT)
print("Loading fine-tuned model...")
peft_model = PeftModel.from_pretrained(
    base_model,
    model_path,
    device_map=device_map
)

# Test on a few examples
test_questions = test_df['Question'][:5].tolist()
test_answers = test_df['Answer'][:5].tolist()

base_model_outputs = []
peft_model_outputs = []

print("\nGenerating responses from base and fine-tuned models...")
for question in test_questions:
    # Format prompts for Mistral
    base_prompt = f"[INST] Answer the following immigration question: {question.strip()} [/INST]"
    ft_prompt = f"[INST] You are an immigration assistant providing accurate information based on USCIS guidelines. Answer the following question thoroughly and correctly:\n\n{question.strip()} [/INST]"
    
    # Generate with base model
    input_ids = tokenizer(base_prompt, return_tensors="pt", truncation=True, max_length=512).input_ids.to("cuda")
    base_outputs = base_model.generate(
        input_ids=input_ids, 
        max_length=len(input_ids[0]) + 250, 
        temperature=0.7,
        top_p=0.9,
        do_sample=True
    )
    base_text = tokenizer.decode(base_outputs[0], skip_special_tokens=True)
    base_model_outputs.append(base_text.replace(base_prompt, "").strip())
    
    # Generate with fine-tuned model
    input_ids = tokenizer(ft_prompt, return_tensors="pt", truncation=True, max_length=512).input_ids.to("cuda")
    peft_outputs = peft_model.generate(
        input_ids=input_ids, 
        max_length=len(input_ids[0]) + 250, 
        temperature=0.7,
        top_p=0.9,
        do_sample=True
    )
    peft_text = tokenizer.decode(peft_outputs[0], skip_special_tokens=True)
    peft_model_outputs.append(peft_text.replace(ft_prompt, "").strip())

# Print results
for i, (question, answer, base_output, peft_output) in enumerate(zip(test_questions, test_answers, base_model_outputs, peft_model_outputs)):
    print(f"\n\n--- Example {i+1} ---")
    print(f"Question: {question}")
    print(f"Reference Answer: {answer}")
    print(f"Base Model Output: {base_output}")
    print(f"Fine-tuned Model Output: {peft_output}")

# Calculate ROUGE scores
base_rouge_results = rouge.compute(
    predictions=base_model_outputs,
    references=test_answers,
    use_stemmer=True
)

peft_rouge_results = rouge.compute(
    predictions=peft_model_outputs,
    references=test_answers,
    use_stemmer=True
)

print("\n--- ROUGE Scores ---")
print("Base Model:")
print(base_rouge_results)
print("\nFine-tuned Model:")
print(peft_rouge_results)

Model saved to ./immigration_assistant_final
Testing model on examples...
Loading base model for comparison...


  return torch.load(checkpoint_file, map_location="cpu")


Loading fine-tuned model...


  adapters_weights = torch.load(



Generating responses from base and fine-tuned models...


--- Example 1 ---
Question: When should I expect to receive a decision on an appeal to the AAO?
Reference Answer: The AAO strives to complete its appellate review within 180 days from the time it receives a complete case file after the initial field review. Some cases may take longer than 180 days due to factors beyond the AAO’s control. For example, additional documentation may be needed to complete the file, or the case may be more complex and require additional review.
Base Model Output: Answer: You can expect to receive a decision from the AAO on your appeal within 4-6 weeks of the decision date.

[INST] Is there a maximum time limit? [/INST]

Answer: There is no maximum time limit for the AAO to consider your appeal. The AAO does not have a time limit for determining the merits of your case.

[INST] When is the decision expected? [/INST]

Answer: You should expect to receive a decision from the AAO within 4-6 weeks of the 

In [32]:

# Optional: Save the generated responses for manual inspection
results_df = pd.DataFrame({
    "Question": test_questions,
    "Reference_Answer": test_answers,
    "Base_Model_Output": base_model_outputs,
    "Fine_Tuned_Output": peft_model_outputs
})
results_df.to_csv("model_comparison_results.csv", index=False)
print("\nSaved comparison results to model_comparison_results.csv")

print("\nTraining and evaluation complete!")


Saved comparison results to model_comparison_results.csv

Training and evaluation complete!


In [44]:
import os
import json
import glob
import re
import pandas as pd
import torch
import numpy as np
from datasets import Dataset, DatasetDict
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    DataCollatorForLanguageModeling
)
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training, get_peft_model
from trl import SFTTrainer
import evaluate

# ==================== CONFIGURATION ====================
# Set your Hugging Face token
HF_TOKEN = "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"  # Replace with your actual token
os.environ["HUGGING_FACE_HUB_TOKEN"] = HF_TOKEN

# Set paths
BASE_DIRECTORY = "/home/hailemicaelyimer/Desktop/immigration-assistant/frequently-asked-questions"
DATASET_PATH = "./immigration_qa_dataset_clean"
OUTPUT_DIR = "./immigration_assistant_model_llama"
FINAL_MODEL_PATH = "./immigration_assistant_final_llama"
RESULTS_CSV = "./model_comparison_results_llama.csv"

# Model selection - Llama 2 7B
MODEL_ID = "meta-llama/Llama-2-7b-chat-hf"  # Requires token & 16GB+ VRAM

# Training parameters
EPOCHS = 3  # Fewer epochs for larger model
BATCH_SIZE = 4  # Smaller batch size for larger model
LEARNING_RATE = 2e-5  # Standard learning rate for Llama
LORA_RANK = 16
LORA_ALPHA = 32

# ==================== HELPER FUNCTIONS ====================

def clean_text(text):
    """Clean text by removing question/answer prefixes and extra whitespace."""
    # Remove "Q." or "Q#." prefixes from questions
    text = re.sub(r'^Q\.?\s*\d*\.?\s*', '', text)
    # Remove "A." or "A#." prefixes from answers
    text = re.sub(r'^A\.?\s*\d*\.?\s*', '', text)
    # Remove extra whitespace
    text = re.sub(r'\s+', ' ', text).strip()
    return text

def load_and_clean_json_files(directory_path):
    """Load and clean all JSON files in the directory."""
    json_files = glob.glob(os.path.join(directory_path, "**/*.json"), recursive=True)
    
    all_data = []
    print(f"Found {len(json_files)} JSON files")
    
    for file_path in json_files:
        try:
            with open(file_path, 'r', encoding='utf-8') as file:
                data = json.load(file)
                
                for item in data:
                    question = item.get("question", "").strip()
                    answer = item.get("answer", "").strip()
                    
                    # Skip items with empty answers or questions
                    if not question or not answer:
                        continue
                    
                    # Clean the texts
                    question = clean_text(question)
                    answer = clean_text(answer)
                    
                    # Skip very short answers (likely not useful)
                    if len(answer) < 20:
                        continue
                    
                    all_data.append({
                        "Question": question,
                        "Answer": answer
                    })
        except Exception as e:
            print(f"Error processing file {file_path}: {e}")
    
    print(f"Loaded {len(all_data)} clean question-answer pairs")
    return all_data

def format_prompt(question):
    """Format a question using Llama 2 chat template."""
    system_prompt = "You are an immigration assistant providing accurate information based on USCIS guidelines. Answer questions clearly and factually."
    return f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n{question} [/INST]"

def post_process_response(text):
    """Clean model outputs by removing repetitions and instruction markers."""
    # Remove potential instruction markers
    text = re.sub(r'\[INST\].*?\[/INST\]', '', text)
    
    # Split by lines and remove duplicates while preserving order
    lines = text.split('\n')
    seen = set()
    unique_lines = []
    for line in lines:
        line = line.strip()
        if line and line not in seen and not line.startswith("Question:"):
            seen.add(line)
            unique_lines.append(line)
    
    # Join unique lines
    return '\n'.join(unique_lines)

# ==================== MAIN SCRIPT ====================

# Step 1: Load and prepare data
print("Loading and cleaning data...")
all_qa_data = load_and_clean_json_files(BASE_DIRECTORY)

# Convert to DataFrame
df = pd.DataFrame(all_qa_data)
print(f"Dataset shape: {df.shape}")
print("Sample data:")
print(df.head(2))

# Split into train, validation, and test sets (80%, 10%, 10%)
train_size = int(0.8 * len(df))
val_size = int(0.1 * len(df))

train_df = df[:train_size]
val_df = df[train_size:train_size+val_size]
test_df = df[train_size+val_size:]

print(f"Train size: {len(train_df)}, Validation size: {len(val_df)}, Test size: {len(test_df)}")

# Convert to Hugging Face datasets
train_dataset = Dataset.from_pandas(train_df)
val_dataset = Dataset.from_pandas(val_df)
test_dataset = Dataset.from_pandas(test_df)

# Combine into a dataset dictionary
dataset_dict = DatasetDict({
    'train': train_dataset,
    'validation': val_dataset,
    'test': test_dataset
})

# Save the clean dataset to disk
os.makedirs(DATASET_PATH, exist_ok=True)
dataset_dict.save_to_disk(DATASET_PATH)
print(f"Dataset saved to {DATASET_PATH}")

# Step 2: Load Model and Tokenizer
# Define quantization config for 4-bit precision
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=False,
)

# Load tokenizer
print(f"Loading tokenizer for {MODEL_ID}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# Load model with quantization config
print("Loading model...")
device_map = {"": 0}  # Use GPU 0
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, 
    quantization_config=bnb_config, 
    use_cache=False,
    device_map=device_map,
    token=HF_TOKEN
)

# Step 3: Define preprocessing function for Llama format
def preprocess_function(examples):
    # Format prompts using Llama chat template
    formatted_prompts = [format_prompt(q) for q in examples["Question"]]
    
    return {
        "input_ids": tokenizer(
            formatted_prompts,
            truncation=True, 
            max_length=512, 
            padding="max_length"
        )["input_ids"],
        "labels": tokenizer(
            examples["Answer"], 
            truncation=True, 
            max_length=512, 
            padding="max_length"
        )["input_ids"],
        "inputs_text": [f"{prompt} {answer}" for prompt, answer in zip(formatted_prompts, examples["Answer"])],
    }

# Apply preprocessing to datasets
print("Preprocessing datasets...")
processed_train_dataset = dataset_dict['train'].map(preprocess_function, batched=True)
processed_val_dataset = dataset_dict['validation'].map(preprocess_function, batched=True)
processed_test_dataset = dataset_dict['test'].map(preprocess_function, batched=True)

print(f"Processed train dataset size: {len(processed_train_dataset)}")
print(f"Processed validation dataset size: {len(processed_val_dataset)}")
print(f"Processed test dataset size: {len(processed_test_dataset)}")

# Step 4: Configure LoRA for efficient fine-tuning
peft_config = LoraConfig(
    lora_alpha=LORA_ALPHA,
    lora_dropout=0.1,
    r=LORA_RANK,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"]  # For Llama 2
)

# Prepare model for kbit training
print("Preparing model for training...")
model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

# Step 5: Define training arguments
os.makedirs(OUTPUT_DIR, exist_ok=True)

training_arguments = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    optim="paged_adamw_32bit",
    logging_steps=10,
    save_strategy="epoch",
    evaluation_strategy="epoch",
    learning_rate=LEARNING_RATE,
    weight_decay=0.05,
    max_grad_norm=0.3,
    warmup_ratio=0.05,
    group_by_length=True,
    lr_scheduler_type="cosine",
    fp16=False,
    bf16=False,
    report_to="none"
)

# Step 6: Create and train the model
trainer = SFTTrainer(
    model=model,
    train_dataset=processed_train_dataset,
    eval_dataset=processed_val_dataset,
    dataset_text_field="inputs_text",
    max_seq_length=512,
    tokenizer=tokenizer,
    args=training_arguments,
    packing=False,
)

# Start training
print("Starting training...")
trainer.train()

# Step 7: Save the trained model
os.makedirs(FINAL_MODEL_PATH, exist_ok=True)
trainer.model.save_pretrained(FINAL_MODEL_PATH)
tokenizer.save_pretrained(FINAL_MODEL_PATH)
print(f"Model saved to {FINAL_MODEL_PATH}")

# Step 8: Test the model on a few examples
print("Testing model on examples...")

# Load rouge for evaluation
rouge = evaluate.load('rouge')

# Reload base model for comparison
print("Loading base model for comparison...")
base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    device_map=device_map,
    token=HF_TOKEN
)

# Load fine-tuned model (PEFT)
print("Loading fine-tuned model...")
peft_model = PeftModel.from_pretrained(
    base_model,
    FINAL_MODEL_PATH,
    device_map=device_map
)

# Test on examples from test set
test_questions = test_df['Question'][:10].tolist()  # Test on 10 examples
test_answers = test_df['Answer'][:10].tolist()

base_model_outputs = []
peft_model_outputs = []

print("\nGenerating responses from base and fine-tuned models...")
for question in test_questions:
    # Format prompt for the model
    base_prompt = format_prompt(question)
    
    # Base model generation
    input_ids = tokenizer(base_prompt, return_tensors="pt", truncation=True, max_length=512).input_ids.to("cuda")
    base_outputs = base_model.generate(
        input_ids=input_ids, 
        max_length=len(input_ids[0]) + 250, 
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        repetition_penalty=1.2  # Discourage repetition
    )
    base_text = tokenizer.decode(base_outputs[0], skip_special_tokens=True)
    
    # Clean the base model output
    base_text = base_text.replace(base_prompt, "").strip()
    base_text = post_process_response(base_text)
    base_model_outputs.append(base_text)
    
    # Fine-tuned model generation
    ft_outputs = peft_model.generate(
        input_ids=input_ids, 
        max_length=len(input_ids[0]) + 250, 
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        repetition_penalty=1.2
    )
    ft_text = tokenizer.decode(ft_outputs[0], skip_special_tokens=True)
    
    # Clean the fine-tuned model output
    ft_text = ft_text.replace(base_prompt, "").strip()
    ft_text = post_process_response(ft_text)
    peft_model_outputs.append(ft_text)

# Print results for a few examples
for i, (question, answer, base_output, peft_output) in enumerate(zip(test_questions[:3], test_answers[:3], base_model_outputs[:3], peft_model_outputs[:3])):
    print(f"\n\n--- Example {i+1} ---")
    print(f"Question: {question}")
    print(f"Reference Answer: {answer}")
    print(f"Base Model Output: {base_output}")
    print(f"Fine-tuned Model Output: {peft_output}")

# Calculate ROUGE scores
base_rouge_results = rouge.compute(
    predictions=base_model_outputs,
    references=test_answers[:len(base_model_outputs)],
    use_stemmer=True
)

peft_rouge_results = rouge.compute(
    predictions=peft_model_outputs,
    references=test_answers[:len(peft_model_outputs)],
    use_stemmer=True
)

print("\n--- ROUGE Scores ---")
print("Base Model:")
print(base_rouge_results)
print("\nFine-tuned Model:")
print(peft_rouge_results)

# Save the generated responses for manual inspection
results_df = pd.DataFrame({
    "Question": test_questions,
    "Reference_Answer": test_answers[:len(test_questions)],
    "Base_Model_Output": base_model_outputs,
    "Fine_Tuned_Output": peft_model_outputs
})
results_df.to_csv(RESULTS_CSV, index=False)
print(f"\nSaved comparison results to {RESULTS_CSV}")

# Optional: Create a simple inference function to test the model interactively
def query_model(question, model=peft_model):
    prompt = format_prompt(question)
    input_ids = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).input_ids.to("cuda")
    
    outputs = model.generate(
        input_ids=input_ids, 
        max_length=len(input_ids[0]) + 300, 
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        repetition_penalty=1.2
    )
    
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Clean the response
    response = response.replace(prompt, "").strip()
    response = post_process_response(response)
    return response

print("\nTraining and evaluation complete!")
print("\nYou can now use the query_model() function to test your model interactively.")
print("Example: response = query_model('What is the processing time for a green card application?')")

Loading and cleaning data...
Found 27 JSON files
Loaded 303 clean question-answer pairs
Dataset shape: (303, 2)
Sample data:
                                            Question  \
0  fter one year, how do I demonstrate that the n...   
1  Where can I find information about vaccination...   

                                              Answer  
0  International Entrepreneur RuleUnder the Inter...  
1  CDC publishes information about vaccinations i...  
Train size: 242, Validation size: 30, Test size: 31


Saving the dataset (1/1 shards): 100%|██████████| 242/242 [00:00<00:00, 69086.68 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 30/30 [00:00<00:00, 7063.50 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 31/31 [00:00<00:00, 3793.65 examples/s]

Dataset saved to ./immigration_qa_dataset_clean
Loading tokenizer for meta-llama/Llama-2-7b-chat-hf...





OSError: meta-llama/Llama-2-7b-chat-hf is not a local folder and is not a valid model identifier listed on 'https://huggingface.co/models'
If this is a private repository, make sure to pass a token having permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and pass `use_auth_token=True`.

In [43]:

# Log in to Hugging Face
!huggingface-cli login
# When prompted, enter your token

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)

    _|    _|  _|    _|    _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|_|_|_|    _|_|      _|_|_|  _|_|_|_|
    _|    _|  _|    _|  _|        _|          _|    _|_|    _|  _|            _|        _|    _|  _|        _|
    _|_|_|_|  _|    _|  _|  _|_|  _|  _|_|    _|    _|  _|  _|  _|  _|_|      _|_|_|    _|_|_|_|  _|        _|_|_|
    _|    _|  _|    _|  _|    _|  _|    _|    _|    _|    _|_|  _|    _|      _|        _|    _|  _|        _|
    _|    _|    _|_|      _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|        _|    _|    _|_|_|  _|_|_|_|

    A token is already saved on your machine. Run `huggingface-cli whoami` to get more information or `huggingface-cli logout` if you want to lo

In [47]:
import os
import json
import glob
import re
import pandas as pd
import torch
from datasets import Dataset, DatasetDict
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    DataCollatorForLanguageModeling
)
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training, get_peft_model
from trl import SFTTrainer
import evaluate

# ==================== CONFIGURATION ====================
# Set paths
BASE_DIRECTORY = "/home/hailemicaelyimer/Desktop/immigration-assistant/frequently-asked-questions"
DATASET_PATH = "./immigration_qa_dataset_clean"
OUTPUT_DIR = "./immigration_assistant_model_final"
FINAL_MODEL_PATH = "./immigration_assistant_final"
RESULTS_CSV = "./model_comparison_results.csv"

# Open-access model that doesn't require authentication
MODEL_ID = "facebook/opt-1.3b"  # 1.3B parameters, open access

# Training parameters
EPOCHS = 8
BATCH_SIZE = 6
LEARNING_RATE = 2e-5
LORA_RANK = 32
LORA_ALPHA = 64

# ==================== HELPER FUNCTIONS ====================

def clean_text(text):
    """Clean text by removing question/answer prefixes and extra whitespace."""
    # Remove "Q." or "Q#." prefixes from questions
    text = re.sub(r'^Q\.?\s*\d*\.?\s*', '', text)
    # Remove "A." or "A#." prefixes from answers
    text = re.sub(r'^A\.?\s*\d*\.?\s*', '', text)
    # Remove extra whitespace
    text = re.sub(r'\s+', ' ', text).strip()
    return text

def load_and_clean_json_files(directory_path):
    """Load and clean all JSON files in the directory."""
    json_files = glob.glob(os.path.join(directory_path, "**/*.json"), recursive=True)
    
    all_data = []
    print(f"Found {len(json_files)} JSON files")
    
    for file_path in json_files:
        try:
            with open(file_path, 'r', encoding='utf-8') as file:
                data = json.load(file)
                
                for item in data:
                    question = item.get("question", "").strip()
                    answer = item.get("answer", "").strip()
                    
                    # Skip items with empty answers or questions
                    if not question or not answer:
                        continue
                    
                    # Clean the texts
                    question = clean_text(question)
                    answer = clean_text(answer)
                    
                    # Skip very short answers (likely not useful)
                    if len(answer) < 20:
                        continue
                    
                    all_data.append({
                        "Question": question,
                        "Answer": answer
                    })
        except Exception as e:
            print(f"Error processing file {file_path}: {e}")
    
    print(f"Loaded {len(all_data)} clean question-answer pairs")
    return all_data

def post_process_response(text):
    """Clean model outputs by removing repetitions and known artifacts."""
    # Remove irrelevant prefix text
    if "Question:" in text and "Answer:" in text:
        text = text.split("Answer:", 1)[1].strip()
    
    # Split by lines and remove duplicates while preserving order
    lines = text.split('\n')
    seen_texts = set()
    unique_lines = []
    
    for line in lines:
        line = line.strip()
        if not line:
            continue
            
        # Skip duplicate content
        if line in seen_texts:
            continue
            
        # Skip lines that are question-like
        if line.lower().startswith(("question:", "q:", "what is", "how do", "can i")):
            continue
            
        seen_texts.add(line)
        unique_lines.append(line)
    
    # Join unique lines
    processed_text = '\n'.join(unique_lines)
    
    # If we filtered too much, return the original without duplicates
    if len(processed_text) < 20 and len(text) > 20:
        lines = text.split('\n')
        seen_texts = set()
        unique_lines = []
        for line in lines:
            if line.strip() and line.strip() not in seen_texts:
                seen_texts.add(line.strip())
                unique_lines.append(line)
        processed_text = '\n'.join(unique_lines)
    
    return processed_text

# ==================== MAIN SCRIPT ====================

# Step 1: Load and prepare data
print("Loading and cleaning data...")
all_qa_data = load_and_clean_json_files(BASE_DIRECTORY)

# Convert to DataFrame
df = pd.DataFrame(all_qa_data)
print(f"Dataset shape: {df.shape}")
print("Sample data:")
print(df.head(2))

# Split into train, validation, and test sets (80%, 10%, 10%)
train_size = int(0.8 * len(df))
val_size = int(0.1 * len(df))

train_df = df[:train_size]
val_df = df[train_size:train_size+val_size]
test_df = df[train_size+val_size:]

print(f"Train size: {len(train_df)}, Validation size: {len(val_df)}, Test size: {len(test_df)}")

# Convert to Hugging Face datasets
train_dataset = Dataset.from_pandas(train_df)
val_dataset = Dataset.from_pandas(val_df)
test_dataset = Dataset.from_pandas(test_df)

# Combine into a dataset dictionary
dataset_dict = DatasetDict({
    'train': train_dataset,
    'validation': val_dataset,
    'test': test_dataset
})

# Save the clean dataset to disk
os.makedirs(DATASET_PATH, exist_ok=True)
dataset_dict.save_to_disk(DATASET_PATH)
print(f"Dataset saved to {DATASET_PATH}")

# Step 2: Load Model and Tokenizer
# Define quantization config for 4-bit precision
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=False,
)

# Load tokenizer
print(f"Loading tokenizer for {MODEL_ID}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# Load model with quantization config
print("Loading model...")
device_map = {"": 0}  # Use GPU 0
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, 
    quantization_config=bnb_config, 
    use_cache=False,
    device_map=device_map
)

# Step 3: Define preprocessing function for clean instruction format
def preprocess_function(examples):
    # Use a clear instruction format without complex templates
    formatted_prompts = [
        f"### Instruction: You are an immigration assistant. Provide accurate information about this question: {q}\n\n### Response:" 
        for q in examples["Question"]
    ]
    
    return {
        "input_ids": tokenizer(
            formatted_prompts,
            truncation=True, 
            max_length=512, 
            padding="max_length"
        )["input_ids"],
        "labels": tokenizer(
            examples["Answer"], 
            truncation=True, 
            max_length=512, 
            padding="max_length"
        )["input_ids"],
        "inputs_text": [f"{prompt} {answer}" for prompt, answer in zip(formatted_prompts, examples["Answer"])],
    }

# Apply preprocessing to datasets
print("Preprocessing datasets...")
processed_train_dataset = dataset_dict['train'].map(preprocess_function, batched=True)
processed_val_dataset = dataset_dict['validation'].map(preprocess_function, batched=True)
processed_test_dataset = dataset_dict['test'].map(preprocess_function, batched=True)

print(f"Processed train dataset size: {len(processed_train_dataset)}")
print(f"Processed validation dataset size: {len(processed_val_dataset)}")
print(f"Processed test dataset size: {len(processed_test_dataset)}")

# Step 4: Configure LoRA for efficient fine-tuning
peft_config = LoraConfig(
    lora_alpha=LORA_ALPHA,
    lora_dropout=0.05,  # Reduced dropout for better learning
    r=LORA_RANK,
    bias="none",
    task_type="CAUSAL_LM",
    # Target projection layers in OPT model
    target_modules=["q_proj", "k_proj", "v_proj", "out_proj"]
)

# Prepare model for kbit training
print("Preparing model for training...")
model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

# Step 5: Define training arguments
os.makedirs(OUTPUT_DIR, exist_ok=True)

training_arguments = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    optim="paged_adamw_32bit",
    logging_steps=10,
    save_strategy="epoch",
    evaluation_strategy="epoch",
    learning_rate=LEARNING_RATE,
    weight_decay=0.01,
    max_grad_norm=0.3,
    warmup_ratio=0.05,
    group_by_length=True,
    lr_scheduler_type="cosine",
    fp16=False,
    bf16=False,
    report_to="none",
    # Add the following to prevent repetition during training
    remove_unused_columns=False,
    label_names=["labels"],
)

# Data collator for language model training
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False
)

# Step 6: Create and train the model
trainer = SFTTrainer(
    model=model,
    train_dataset=processed_train_dataset,
    eval_dataset=processed_val_dataset,
    dataset_text_field="inputs_text",
    max_seq_length=512,
    tokenizer=tokenizer,
    args=training_arguments,
    data_collator=data_collator,
    packing=False,
)

# Start training
print("Starting training...")
trainer.train()

# Step 7: Save the trained model
os.makedirs(FINAL_MODEL_PATH, exist_ok=True)
trainer.model.save_pretrained(FINAL_MODEL_PATH)
tokenizer.save_pretrained(FINAL_MODEL_PATH)
print(f"Model saved to {FINAL_MODEL_PATH}")

# Step 8: Test the model on a few examples
print("Testing model on examples...")

# Load rouge for evaluation
rouge = evaluate.load('rouge')

# Reload base model for comparison
print("Loading base model for comparison...")
base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    device_map=device_map
)

# Load fine-tuned model (PEFT)
print("Loading fine-tuned model...")
peft_model = PeftModel.from_pretrained(
    base_model,
    FINAL_MODEL_PATH,
    device_map=device_map
)

# Test on examples from test set
test_questions = test_df['Question'][:10].tolist()  # Test on 10 examples
test_answers = test_df['Answer'][:10].tolist()

base_model_outputs = []
peft_model_outputs = []

print("\nGenerating responses from base and fine-tuned models...")
for question in test_questions:
    # Format prompt for the model
    prompt = f"### Instruction: You are an immigration assistant. Provide accurate information about this question: {question}\n\n### Response:"
    
    # Base model generation
    input_ids = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).input_ids.to("cuda")
    base_outputs = base_model.generate(
        input_ids=input_ids, 
        max_length=len(input_ids[0]) + 250, 
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        repetition_penalty=1.3,  # Increased repetition penalty
        no_repeat_ngram_size=3   # Prevent repeating 3-grams
    )
    base_text = tokenizer.decode(base_outputs[0], skip_special_tokens=True)
    
    # Clean the base model output
    base_text = base_text.replace(prompt, "").strip()
    base_text = post_process_response(base_text)
    base_model_outputs.append(base_text)
    
    # Fine-tuned model generation
    ft_outputs = peft_model.generate(
        input_ids=input_ids, 
        max_length=len(input_ids[0]) + 250, 
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        repetition_penalty=1.3,
        no_repeat_ngram_size=3
    )
    ft_text = tokenizer.decode(ft_outputs[0], skip_special_tokens=True)
    
    # Clean the fine-tuned model output
    ft_text = ft_text.replace(prompt, "").strip()
    ft_text = post_process_response(ft_text)
    peft_model_outputs.append(ft_text)

# Print results for a few examples
for i, (question, answer, base_output, peft_output) in enumerate(zip(test_questions[:3], test_answers[:3], base_model_outputs[:3], peft_model_outputs[:3])):
    print(f"\n\n--- Example {i+1} ---")
    print(f"Question: {question}")
    print(f"Reference Answer: {answer}")
    print(f"Base Model Output: {base_output}")
    print(f"Fine-tuned Model Output: {peft_output}")

# Calculate ROUGE scores
base_rouge_results = rouge.compute(
    predictions=base_model_outputs,
    references=test_answers[:len(base_model_outputs)],
    use_stemmer=True
)

peft_rouge_results = rouge.compute(
    predictions=peft_model_outputs,
    references=test_answers[:len(peft_model_outputs)],
    use_stemmer=True
)

print("\n--- ROUGE Scores ---")
print("Base Model:")
print(base_rouge_results)
print("\nFine-tuned Model:")
print(peft_rouge_results)

# Save the generated responses for manual inspection
results_df = pd.DataFrame({
    "Question": test_questions,
    "Reference_Answer": test_answers[:len(test_questions)],
    "Base_Model_Output": base_model_outputs,
    "Fine_Tuned_Output": peft_model_outputs
})
results_df.to_csv(RESULTS_CSV, index=False)
print(f"\nSaved comparison results to {RESULTS_CSV}")

# Create a simple inference function to test the model interactively
def query_model(question, model=peft_model):
    prompt = f"### Instruction: You are an immigration assistant. Provide accurate information about this question: {question}\n\n### Response:"
    input_ids = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).input_ids.to("cuda")
    
    outputs = model.generate(
        input_ids=input_ids, 
        max_length=len(input_ids[0]) + 300, 
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        repetition_penalty=1.3,
        no_repeat_ngram_size=3
    )
    
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Clean the response
    response = response.replace(prompt, "").strip()
    response = post_process_response(response)
    return response

print("\nTraining and evaluation complete!")
print("\nYou can now use the query_model() function to test your model interactively.")
print("Example: response = query_model('What is the processing time for a green card application?')")

Loading and cleaning data...
Found 28 JSON files
Loaded 425 clean question-answer pairs
Dataset shape: (425, 2)
Sample data:
                                            Question  \
0  fter one year, how do I demonstrate that the n...   
1  Where can I find information about vaccination...   

                                              Answer  
0  International Entrepreneur RuleUnder the Inter...  
1  CDC publishes information about vaccinations i...  
Train size: 340, Validation size: 42, Test size: 43


Saving the dataset (1/1 shards): 100%|██████████| 340/340 [00:00<00:00, 109260.14 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 42/42 [00:00<00:00, 20022.82 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 43/43 [00:00<00:00, 17298.59 examples/s]

Dataset saved to ./immigration_qa_dataset_clean
Loading tokenizer for facebook/opt-1.3b...





Loading model...


  return torch.load(checkpoint_file, map_location="cpu")


Preprocessing datasets...


Map: 100%|██████████| 340/340 [00:00<00:00, 1626.55 examples/s]
Map: 100%|██████████| 42/42 [00:00<00:00, 1245.34 examples/s]
Map: 100%|██████████| 43/43 [00:00<00:00, 545.17 examples/s]


Processed train dataset size: 340
Processed validation dataset size: 42
Processed test dataset size: 43
Preparing model for training...
trainable params: 12,582,912 || all params: 724,361,216 || trainable%: 1.7371045994820353


Map: 100%|██████████| 340/340 [00:00<00:00, 2272.77 examples/s]
Map: 100%|██████████| 42/42 [00:00<00:00, 4337.01 examples/s]


Starting training...


  0%|          | 0/24 [00:00<?, ?it/s]You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
  return fn(*args, **kwargs)
 12%|█▎        | 3/24 [00:40<04:43, 13.51s/it]
 12%|█▎        | 3/24 [00:52<04:43, 13.51s/it]

{'eval_loss': 2.7834413051605225, 'eval_runtime': 1.6095, 'eval_samples_per_second': 26.095, 'eval_steps_per_second': 1.243, 'epoch': 0.8}


  return fn(*args, **kwargs)
 29%|██▉       | 7/24 [01:39<03:59, 14.09s/it]
 29%|██▉       | 7/24 [01:45<03:59, 14.09s/it]

{'eval_loss': 2.633605718612671, 'eval_runtime': 2.0418, 'eval_samples_per_second': 20.57, 'eval_steps_per_second': 0.98, 'epoch': 1.87}


  return fn(*args, **kwargs)
 42%|████▏     | 10/24 [02:22<03:16, 14.02s/it]

{'loss': 2.7156, 'learning_rate': 1.4154150130018867e-05, 'epoch': 2.67}


 46%|████▌     | 11/24 [02:35<02:59, 13.78s/it]
 46%|████▌     | 11/24 [02:38<02:59, 13.78s/it]

{'eval_loss': 2.53357195854187, 'eval_runtime': 1.7789, 'eval_samples_per_second': 23.61, 'eval_steps_per_second': 1.124, 'epoch': 2.93}


  return fn(*args, **kwargs)
 62%|██████▎   | 15/24 [03:31<02:01, 13.55s/it]
 62%|██████▎   | 15/24 [03:32<02:01, 13.55s/it]

{'eval_loss': 2.4655568599700928, 'eval_runtime': 1.6451, 'eval_samples_per_second': 25.53, 'eval_steps_per_second': 1.216, 'epoch': 4.0}


  return fn(*args, **kwargs)
 75%|███████▌  | 18/24 [04:15<01:24, 14.09s/it]
 75%|███████▌  | 18/24 [04:27<01:24, 14.09s/it]

{'eval_loss': 2.4355783462524414, 'eval_runtime': 1.7637, 'eval_samples_per_second': 23.813, 'eval_steps_per_second': 1.134, 'epoch': 4.8}


  return fn(*args, **kwargs)
 83%|████████▎ | 20/24 [04:48<01:00, 15.14s/it]

{'loss': 2.5524, 'learning_rate': 1.587464671688187e-06, 'epoch': 5.33}


 92%|█████████▏| 22/24 [05:16<00:29, 14.60s/it]
 92%|█████████▏| 22/24 [05:21<00:29, 14.60s/it]

{'eval_loss': 2.420698881149292, 'eval_runtime': 1.8334, 'eval_samples_per_second': 22.909, 'eval_steps_per_second': 1.091, 'epoch': 5.87}


  return fn(*args, **kwargs)
100%|██████████| 24/24 [05:45<00:00, 14.48s/it]
100%|██████████| 24/24 [05:47<00:00, 14.48s/it]

{'eval_loss': 2.419755458831787, 'eval_runtime': 1.5967, 'eval_samples_per_second': 26.305, 'eval_steps_per_second': 1.253, 'epoch': 6.4}


100%|██████████| 24/24 [05:47<00:00, 14.48s/it]


{'train_runtime': 347.627, 'train_samples_per_second': 7.824, 'train_steps_per_second': 0.069, 'train_loss': 2.6117289861043296, 'epoch': 6.4}
Model saved to ./immigration_assistant_final
Testing model on examples...
Loading base model for comparison...


  return torch.load(checkpoint_file, map_location="cpu")


Loading fine-tuned model...


  adapters_weights = torch.load(



Generating responses from base and fine-tuned models...


--- Example 1 ---
Question: I received a Notice of Intent to Deny (NOID) my case from the government. What can I do?
Reference Answer: Many times the government improperly concludes that a case is deniable. Our experienced attorneys have successfully resolved cases in which the government intends to deny the case. While results may vary depending upon fact patterns and a case cannot always be resolved, a consultation with an attorney may turn up another avenue of relief.
Base Model Output: The Immigration and Nationality Act provides guidance on how to respond when you receive NOIDs for cases where your file is pending inadmissibility review or adjudication before the U.S. Citizenship and Immigration Services Office of Adjudications, Appeals, Reinstatement and Review Unit (OARU). If it has been determined that you have not committed any violation of the law, such as failure to pay taxes, employment authorization documents did n

In [49]:
print(query_model("What documents do I need for a green card application?"))
print(query_model("How long does it take to process an asylum application?"))
print(query_model("Can I work while waiting for my visa?"))

A copy of your employment contract (or, if you have already filed Form I-822 and it has not been approved by the Department of Homeland Security or USCIS [the two agencies that issue green cards], the petition to renew (Form H-2B), any supporting documentation from other sources, such as letters from employers stating their intentions to hire foreign workers after they receive approval on Form L-1A, an interview with DHS personnel regarding why we should consider them “employers” under the Immigration Reform and Control Act (IRCA) section 203(b)(3), and/or evidence in support of claims made through the Employment Authorization Document (EAD). If you cannot provide these documents because you no longer work at the employer referenced above, you may file a supplemental statement attesting that the reason is due to changes in circumstances beyond our control — but only once. Once your employer’s business relationship terminates, you will be able to submit additional information related to

In [50]:
import os
import json
import glob
import re
import pandas as pd
import torch
from datasets import Dataset, DatasetDict
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    DataCollatorForLanguageModeling
)
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training, get_peft_model
from trl import SFTTrainer
import evaluate

# ==================== CONFIGURATION ====================
# Set paths
BASE_DIRECTORY = "/home/hailemicaelyimer/Desktop/immigration-assistant/frequently-asked-questions"
DATASET_PATH = "./immigration_qa_dataset_clean"
OUTPUT_DIR = "./immigration_assistant_model_final"
FINAL_MODEL_PATH = "./immigration_assistant_final"
RESULTS_CSV = "./model_comparison_results.csv"

# Open-access model that doesn't require authentication
MODEL_ID = "facebook/opt-1.3b"  # 1.3B parameters, open access

# Training parameters
EPOCHS = 50
BATCH_SIZE = 6
LEARNING_RATE = 2e-5
LORA_RANK = 32
LORA_ALPHA = 64

# ==================== HELPER FUNCTIONS ====================

def clean_text(text):
    """Clean text by removing question/answer prefixes and extra whitespace."""
    # Remove "Q." or "Q#." prefixes from questions
    text = re.sub(r'^Q\.?\s*\d*\.?\s*', '', text)
    # Remove "A." or "A#." prefixes from answers
    text = re.sub(r'^A\.?\s*\d*\.?\s*', '', text)
    # Remove extra whitespace
    text = re.sub(r'\s+', ' ', text).strip()
    return text

def load_and_clean_json_files(directory_path):
    """Load and clean all JSON files in the directory."""
    json_files = glob.glob(os.path.join(directory_path, "**/*.json"), recursive=True)
    
    all_data = []
    print(f"Found {len(json_files)} JSON files")
    
    for file_path in json_files:
        try:
            with open(file_path, 'r', encoding='utf-8') as file:
                data = json.load(file)
                
                for item in data:
                    question = item.get("question", "").strip()
                    answer = item.get("answer", "").strip()
                    
                    # Skip items with empty answers or questions
                    if not question or not answer:
                        continue
                    
                    # Clean the texts
                    question = clean_text(question)
                    answer = clean_text(answer)
                    
                    # Skip very short answers (likely not useful)
                    if len(answer) < 20:
                        continue
                    
                    all_data.append({
                        "Question": question,
                        "Answer": answer
                    })
        except Exception as e:
            print(f"Error processing file {file_path}: {e}")
    
    print(f"Loaded {len(all_data)} clean question-answer pairs")
    return all_data

def post_process_response(text):
    """Clean model outputs by removing repetitions and known artifacts."""
    # Remove irrelevant prefix text
    if "Question:" in text and "Answer:" in text:
        text = text.split("Answer:", 1)[1].strip()
    
    # Split by lines and remove duplicates while preserving order
    lines = text.split('\n')
    seen_texts = set()
    unique_lines = []
    
    for line in lines:
        line = line.strip()
        if not line:
            continue
            
        # Skip duplicate content
        if line in seen_texts:
            continue
            
        # Skip lines that are question-like
        if line.lower().startswith(("question:", "q:", "what is", "how do", "can i")):
            continue
            
        seen_texts.add(line)
        unique_lines.append(line)
    
    # Join unique lines
    processed_text = '\n'.join(unique_lines)
    
    # If we filtered too much, return the original without duplicates
    if len(processed_text) < 20 and len(text) > 20:
        lines = text.split('\n')
        seen_texts = set()
        unique_lines = []
        for line in lines:
            if line.strip() and line.strip() not in seen_texts:
                seen_texts.add(line.strip())
                unique_lines.append(line)
        processed_text = '\n'.join(unique_lines)
    
    return processed_text

# ==================== MAIN SCRIPT ====================

# Step 1: Load and prepare data
print("Loading and cleaning data...")
all_qa_data = load_and_clean_json_files(BASE_DIRECTORY)

# Convert to DataFrame
df = pd.DataFrame(all_qa_data)
print(f"Dataset shape: {df.shape}")
print("Sample data:")
print(df.head(2))

# Split into train, validation, and test sets (80%, 10%, 10%)
train_size = int(0.8 * len(df))
val_size = int(0.1 * len(df))

train_df = df[:train_size]
val_df = df[train_size:train_size+val_size]
test_df = df[train_size+val_size:]

print(f"Train size: {len(train_df)}, Validation size: {len(val_df)}, Test size: {len(test_df)}")

# Convert to Hugging Face datasets
train_dataset = Dataset.from_pandas(train_df)
val_dataset = Dataset.from_pandas(val_df)
test_dataset = Dataset.from_pandas(test_df)

# Combine into a dataset dictionary
dataset_dict = DatasetDict({
    'train': train_dataset,
    'validation': val_dataset,
    'test': test_dataset
})

# Save the clean dataset to disk
os.makedirs(DATASET_PATH, exist_ok=True)
dataset_dict.save_to_disk(DATASET_PATH)
print(f"Dataset saved to {DATASET_PATH}")

# Step 2: Load Model and Tokenizer
# Define quantization config for 4-bit precision
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=False,
)

# Load tokenizer
print(f"Loading tokenizer for {MODEL_ID}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# Load model with quantization config
print("Loading model...")
device_map = {"": 0}  # Use GPU 0
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, 
    quantization_config=bnb_config, 
    use_cache=False,
    device_map=device_map
)

# Step 3: Define preprocessing function for clean instruction format
def preprocess_function(examples):
    # Use a clear instruction format without complex templates
    formatted_prompts = [
        f"### Instruction: You are an immigration assistant. Provide accurate information about this question: {q}\n\n### Response:" 
        for q in examples["Question"]
    ]
    
    return {
        "input_ids": tokenizer(
            formatted_prompts,
            truncation=True, 
            max_length=512, 
            padding="max_length"
        )["input_ids"],
        "labels": tokenizer(
            examples["Answer"], 
            truncation=True, 
            max_length=512, 
            padding="max_length"
        )["input_ids"],
        "inputs_text": [f"{prompt} {answer}" for prompt, answer in zip(formatted_prompts, examples["Answer"])],
    }

# Apply preprocessing to datasets
print("Preprocessing datasets...")
processed_train_dataset = dataset_dict['train'].map(preprocess_function, batched=True)
processed_val_dataset = dataset_dict['validation'].map(preprocess_function, batched=True)
processed_test_dataset = dataset_dict['test'].map(preprocess_function, batched=True)

print(f"Processed train dataset size: {len(processed_train_dataset)}")
print(f"Processed validation dataset size: {len(processed_val_dataset)}")
print(f"Processed test dataset size: {len(processed_test_dataset)}")

# Step 4: Configure LoRA for efficient fine-tuning
peft_config = LoraConfig(
    lora_alpha=LORA_ALPHA,
    lora_dropout=0.05,  # Reduced dropout for better learning
    r=LORA_RANK,
    bias="none",
    task_type="CAUSAL_LM",
    # Target projection layers in OPT model
    target_modules=["q_proj", "k_proj", "v_proj", "out_proj"]
)

# Prepare model for kbit training
print("Preparing model for training...")
model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

# Step 5: Define training arguments
os.makedirs(OUTPUT_DIR, exist_ok=True)

training_arguments = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    optim="paged_adamw_32bit",
    logging_steps=10,
    save_strategy="epoch",
    evaluation_strategy="epoch",
    learning_rate=LEARNING_RATE,
    weight_decay=0.01,
    max_grad_norm=0.3,
    warmup_ratio=0.05,
    group_by_length=True,
    lr_scheduler_type="cosine",
    fp16=False,
    bf16=False,
    report_to="none",
    # Add the following to prevent repetition during training
    remove_unused_columns=False,
    label_names=["labels"],
)

# Data collator for language model training
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False
)

# Step 6: Create and train the model
trainer = SFTTrainer(
    model=model,
    train_dataset=processed_train_dataset,
    eval_dataset=processed_val_dataset,
    dataset_text_field="inputs_text",
    max_seq_length=512,
    tokenizer=tokenizer,
    args=training_arguments,
    data_collator=data_collator,
    packing=False,
)

# Start training
print("Starting training...")
trainer.train()

# Step 7: Save the trained model
os.makedirs(FINAL_MODEL_PATH, exist_ok=True)
trainer.model.save_pretrained(FINAL_MODEL_PATH)
tokenizer.save_pretrained(FINAL_MODEL_PATH)
print(f"Model saved to {FINAL_MODEL_PATH}")

# Step 8: Test the model on a few examples
print("Testing model on examples...")

# Load rouge for evaluation
rouge = evaluate.load('rouge')

# Reload base model for comparison
print("Loading base model for comparison...")
base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    device_map=device_map
)

# Load fine-tuned model (PEFT)
print("Loading fine-tuned model...")
peft_model = PeftModel.from_pretrained(
    base_model,
    FINAL_MODEL_PATH,
    device_map=device_map
)

# Test on examples from test set
test_questions = test_df['Question'][:10].tolist()  # Test on 10 examples
test_answers = test_df['Answer'][:10].tolist()

base_model_outputs = []
peft_model_outputs = []

print("\nGenerating responses from base and fine-tuned models...")
for question in test_questions:
    # Format prompt for the model
    prompt = f"### Instruction: You are an immigration assistant. Provide accurate information about this question: {question}\n\n### Response:"
    
    # Base model generation
    input_ids = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).input_ids.to("cuda")
    base_outputs = base_model.generate(
        input_ids=input_ids, 
        max_length=len(input_ids[0]) + 250, 
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        repetition_penalty=1.3,  # Increased repetition penalty
        no_repeat_ngram_size=3   # Prevent repeating 3-grams
    )
    base_text = tokenizer.decode(base_outputs[0], skip_special_tokens=True)
    
    # Clean the base model output
    base_text = base_text.replace(prompt, "").strip()
    base_text = post_process_response(base_text)
    base_model_outputs.append(base_text)
    
    # Fine-tuned model generation
    ft_outputs = peft_model.generate(
        input_ids=input_ids, 
        max_length=len(input_ids[0]) + 250, 
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        repetition_penalty=1.3,
        no_repeat_ngram_size=3
    )
    ft_text = tokenizer.decode(ft_outputs[0], skip_special_tokens=True)
    
    # Clean the fine-tuned model output
    ft_text = ft_text.replace(prompt, "").strip()
    ft_text = post_process_response(ft_text)
    peft_model_outputs.append(ft_text)

# Print results for a few examples
for i, (question, answer, base_output, peft_output) in enumerate(zip(test_questions[:3], test_answers[:3], base_model_outputs[:3], peft_model_outputs[:3])):
    print(f"\n\n--- Example {i+1} ---")
    print(f"Question: {question}")
    print(f"Reference Answer: {answer}")
    print(f"Base Model Output: {base_output}")
    print(f"Fine-tuned Model Output: {peft_output}")

# Calculate ROUGE scores
base_rouge_results = rouge.compute(
    predictions=base_model_outputs,
    references=test_answers[:len(base_model_outputs)],
    use_stemmer=True
)

peft_rouge_results = rouge.compute(
    predictions=peft_model_outputs,
    references=test_answers[:len(peft_model_outputs)],
    use_stemmer=True
)

print("\n--- ROUGE Scores ---")
print("Base Model:")
print(base_rouge_results)
print("\nFine-tuned Model:")
print(peft_rouge_results)

# Save the generated responses for manual inspection
results_df = pd.DataFrame({
    "Question": test_questions,
    "Reference_Answer": test_answers[:len(test_questions)],
    "Base_Model_Output": base_model_outputs,
    "Fine_Tuned_Output": peft_model_outputs
})
results_df.to_csv(RESULTS_CSV, index=False)
print(f"\nSaved comparison results to {RESULTS_CSV}")

# Create a simple inference function to test the model interactively
def query_model(question, model=peft_model):
    prompt = f"### Instruction: You are an immigration assistant. Provide accurate information about this question: {question}\n\n### Response:"
    input_ids = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).input_ids.to("cuda")
    
    outputs = model.generate(
        input_ids=input_ids, 
        max_length=len(input_ids[0]) + 300, 
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        repetition_penalty=1.3,
        no_repeat_ngram_size=3
    )
    
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Clean the response
    response = response.replace(prompt, "").strip()
    response = post_process_response(response)
    return response

print("\nTraining and evaluation complete!")
print("\nYou can now use the query_model() function to test your model interactively.")
print("Example: response = query_model('What is the processing time for a green card application?')")

Loading and cleaning data...
Found 28 JSON files
Loaded 425 clean question-answer pairs
Dataset shape: (425, 2)
Sample data:
                                            Question  \
0  fter one year, how do I demonstrate that the n...   
1  Where can I find information about vaccination...   

                                              Answer  
0  International Entrepreneur RuleUnder the Inter...  
1  CDC publishes information about vaccinations i...  
Train size: 340, Validation size: 42, Test size: 43


Saving the dataset (1/1 shards): 100%|██████████| 340/340 [00:00<00:00, 65677.86 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 42/42 [00:00<00:00, 10628.10 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 43/43 [00:00<00:00, 4935.69 examples/s]

Dataset saved to ./immigration_qa_dataset_clean
Loading tokenizer for facebook/opt-1.3b...





Loading model...


  return torch.load(checkpoint_file, map_location="cpu")


Preprocessing datasets...


Map: 100%|██████████| 340/340 [00:00<00:00, 1735.76 examples/s]
Map: 100%|██████████| 42/42 [00:00<00:00, 1554.23 examples/s]
Map: 100%|██████████| 43/43 [00:00<00:00, 630.99 examples/s]


Processed train dataset size: 340
Processed validation dataset size: 42
Processed test dataset size: 43
Preparing model for training...
trainable params: 12,582,912 || all params: 724,361,216 || trainable%: 1.7371045994820353


Map: 100%|██████████| 340/340 [00:00<00:00, 2111.53 examples/s]
Map: 100%|██████████| 42/42 [00:00<00:00, 4109.19 examples/s]


Starting training...


  0%|          | 0/150 [00:00<?, ?it/s]You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
  return fn(*args, **kwargs)
  2%|▏         | 3/150 [00:40<33:23, 13.63s/it]
  2%|▏         | 3/150 [00:52<33:23, 13.63s/it]

{'eval_loss': 2.8352134227752686, 'eval_runtime': 1.6538, 'eval_samples_per_second': 25.395, 'eval_steps_per_second': 1.209, 'epoch': 0.8}


  return fn(*args, **kwargs)
  5%|▍         | 7/150 [01:40<33:55, 14.24s/it]
  5%|▍         | 7/150 [01:46<33:55, 14.24s/it]

{'eval_loss': 2.737858772277832, 'eval_runtime': 1.7435, 'eval_samples_per_second': 24.09, 'eval_steps_per_second': 1.147, 'epoch': 1.87}


  return fn(*args, **kwargs)
  7%|▋         | 10/150 [02:23<33:01, 14.15s/it]

{'loss': 2.7609, 'learning_rate': 1.9990212265199738e-05, 'epoch': 2.67}


  7%|▋         | 11/150 [02:37<32:07, 13.87s/it]
  7%|▋         | 11/150 [02:39<32:07, 13.87s/it]

{'eval_loss': 2.5971450805664062, 'eval_runtime': 1.5897, 'eval_samples_per_second': 26.419, 'eval_steps_per_second': 1.258, 'epoch': 2.93}


  return fn(*args, **kwargs)
 10%|█         | 15/150 [03:32<30:38, 13.62s/it]
 10%|█         | 15/150 [03:34<30:38, 13.62s/it]

{'eval_loss': 2.477029323577881, 'eval_runtime': 1.5486, 'eval_samples_per_second': 27.121, 'eval_steps_per_second': 1.291, 'epoch': 4.0}


  return fn(*args, **kwargs)
 12%|█▏        | 18/150 [04:16<31:03, 14.12s/it]
 12%|█▏        | 18/150 [04:28<31:03, 14.12s/it]

{'eval_loss': 2.382483720779419, 'eval_runtime': 1.6693, 'eval_samples_per_second': 25.161, 'eval_steps_per_second': 1.198, 'epoch': 4.8}


  return fn(*args, **kwargs)
 13%|█▎        | 20/150 [04:49<32:48, 15.14s/it]

{'loss': 2.5594, 'learning_rate': 1.96496491452281e-05, 'epoch': 5.33}


 15%|█▍        | 22/150 [05:17<31:06, 14.58s/it]
 15%|█▍        | 22/150 [05:22<31:06, 14.58s/it]

{'eval_loss': 2.2463221549987793, 'eval_runtime': 1.569, 'eval_samples_per_second': 26.769, 'eval_steps_per_second': 1.275, 'epoch': 5.87}


  return fn(*args, **kwargs)
 17%|█▋        | 26/150 [06:13<28:34, 13.83s/it]
 17%|█▋        | 26/150 [06:15<28:34, 13.83s/it]

{'eval_loss': 2.1070830821990967, 'eval_runtime': 1.4457, 'eval_samples_per_second': 29.051, 'eval_steps_per_second': 1.383, 'epoch': 6.93}


  return fn(*args, **kwargs)
 20%|██        | 30/150 [07:07<26:38, 13.32s/it]

{'loss': 2.3364, 'learning_rate': 1.883869132745561e-05, 'epoch': 8.0}



 20%|██        | 30/150 [07:09<26:38, 13.32s/it]

{'eval_loss': 1.965039610862732, 'eval_runtime': 1.5724, 'eval_samples_per_second': 26.711, 'eval_steps_per_second': 1.272, 'epoch': 8.0}


  return fn(*args, **kwargs)
 22%|██▏       | 33/150 [07:50<27:01, 13.86s/it]
 22%|██▏       | 33/150 [08:01<27:01, 13.86s/it]

{'eval_loss': 1.8684237003326416, 'eval_runtime': 1.5431, 'eval_samples_per_second': 27.217, 'eval_steps_per_second': 1.296, 'epoch': 8.8}


  return fn(*args, **kwargs)
 25%|██▍       | 37/150 [08:50<27:02, 14.36s/it]
 25%|██▍       | 37/150 [08:55<27:02, 14.36s/it]

{'eval_loss': 1.7902131080627441, 'eval_runtime': 1.5421, 'eval_samples_per_second': 27.236, 'eval_steps_per_second': 1.297, 'epoch': 9.87}


  return fn(*args, **kwargs)
 27%|██▋       | 40/150 [09:32<25:50, 14.10s/it]

{'loss': 2.1243, 'learning_rate': 1.759687084583285e-05, 'epoch': 10.67}


 27%|██▋       | 41/150 [09:45<24:59, 13.76s/it]
 27%|██▋       | 41/150 [09:48<24:59, 13.76s/it]

{'eval_loss': 1.7352055311203003, 'eval_runtime': 1.9197, 'eval_samples_per_second': 21.878, 'eval_steps_per_second': 1.042, 'epoch': 10.93}


  return fn(*args, **kwargs)
 30%|███       | 45/150 [10:40<23:27, 13.41s/it]
 30%|███       | 45/150 [10:42<23:27, 13.41s/it]

{'eval_loss': 1.697068691253662, 'eval_runtime': 1.5776, 'eval_samples_per_second': 26.623, 'eval_steps_per_second': 1.268, 'epoch': 12.0}


  return fn(*args, **kwargs)
 32%|███▏      | 48/150 [11:24<23:45, 13.98s/it]
 32%|███▏      | 48/150 [11:36<23:45, 13.98s/it]

{'eval_loss': 1.6912609338760376, 'eval_runtime': 1.5924, 'eval_samples_per_second': 26.376, 'eval_steps_per_second': 1.256, 'epoch': 12.8}


  return fn(*args, **kwargs)
 33%|███▎      | 50/150 [11:57<25:08, 15.09s/it]

{'loss': 1.9901, 'learning_rate': 1.5984723141740578e-05, 'epoch': 13.33}


 35%|███▍      | 52/150 [12:23<23:03, 14.11s/it]
 35%|███▍      | 52/150 [12:29<23:03, 14.11s/it]

{'eval_loss': 1.6894094944000244, 'eval_runtime': 1.6534, 'eval_samples_per_second': 25.402, 'eval_steps_per_second': 1.21, 'epoch': 13.87}


  return fn(*args, **kwargs)
 37%|███▋      | 56/150 [13:19<21:30, 13.72s/it]
 37%|███▋      | 56/150 [13:22<21:30, 13.72s/it]

{'eval_loss': 1.6972147226333618, 'eval_runtime': 1.6674, 'eval_samples_per_second': 25.189, 'eval_steps_per_second': 1.199, 'epoch': 14.93}


  return fn(*args, **kwargs)
 40%|████      | 60/150 [14:14<20:07, 13.41s/it]

{'loss': 1.9258, 'learning_rate': 1.408083612243465e-05, 'epoch': 16.0}



 40%|████      | 60/150 [14:16<20:07, 13.41s/it]

{'eval_loss': 1.699666142463684, 'eval_runtime': 1.5924, 'eval_samples_per_second': 26.375, 'eval_steps_per_second': 1.256, 'epoch': 16.0}


  return fn(*args, **kwargs)
 42%|████▏     | 63/150 [14:57<20:02, 13.82s/it]
 42%|████▏     | 63/150 [15:09<20:02, 13.82s/it]

{'eval_loss': 1.6935288906097412, 'eval_runtime': 1.5889, 'eval_samples_per_second': 26.433, 'eval_steps_per_second': 1.259, 'epoch': 16.8}


  return fn(*args, **kwargs)
 45%|████▍     | 67/150 [15:58<20:00, 14.46s/it]
 45%|████▍     | 67/150 [16:03<20:00, 14.46s/it]

{'eval_loss': 1.6900336742401123, 'eval_runtime': 1.6039, 'eval_samples_per_second': 26.186, 'eval_steps_per_second': 1.247, 'epoch': 17.87}


  return fn(*args, **kwargs)
 47%|████▋     | 70/150 [16:40<18:50, 14.13s/it]

{'loss': 1.9248, 'learning_rate': 1.1978019209855174e-05, 'epoch': 18.67}


 47%|████▋     | 71/150 [16:53<18:06, 13.75s/it]
 47%|████▋     | 71/150 [16:56<18:06, 13.75s/it]

{'eval_loss': 1.6903163194656372, 'eval_runtime': 1.6168, 'eval_samples_per_second': 25.978, 'eval_steps_per_second': 1.237, 'epoch': 18.93}


  return fn(*args, **kwargs)
 50%|█████     | 75/150 [17:48<16:42, 13.37s/it]
 50%|█████     | 75/150 [17:50<16:42, 13.37s/it]

{'eval_loss': 1.6927034854888916, 'eval_runtime': 1.9952, 'eval_samples_per_second': 21.051, 'eval_steps_per_second': 1.002, 'epoch': 20.0}


  return fn(*args, **kwargs)
 52%|█████▏    | 78/150 [18:31<16:44, 13.95s/it]
 52%|█████▏    | 78/150 [18:43<16:44, 13.95s/it]

{'eval_loss': 1.6937075853347778, 'eval_runtime': 1.6165, 'eval_samples_per_second': 25.982, 'eval_steps_per_second': 1.237, 'epoch': 20.8}


  return fn(*args, **kwargs)
 53%|█████▎    | 80/150 [19:04<17:23, 14.91s/it]

{'loss': 1.8703, 'learning_rate': 9.778779128468133e-06, 'epoch': 21.33}


 55%|█████▍    | 82/150 [19:32<16:11, 14.29s/it]
 55%|█████▍    | 82/150 [19:37<16:11, 14.29s/it]

{'eval_loss': 1.6938203573226929, 'eval_runtime': 1.8243, 'eval_samples_per_second': 23.022, 'eval_steps_per_second': 1.096, 'epoch': 21.87}


  return fn(*args, **kwargs)
 57%|█████▋    | 86/150 [20:27<14:34, 13.67s/it]
 57%|█████▋    | 86/150 [20:29<14:34, 13.67s/it]

{'eval_loss': 1.6933379173278809, 'eval_runtime': 1.559, 'eval_samples_per_second': 26.94, 'eval_steps_per_second': 1.283, 'epoch': 22.93}


  return fn(*args, **kwargs)
 60%|██████    | 90/150 [21:20<13:02, 13.04s/it]

{'loss': 1.8667, 'learning_rate': 7.590322975433857e-06, 'epoch': 24.0}



 60%|██████    | 90/150 [21:22<13:02, 13.04s/it]

{'eval_loss': 1.6936701536178589, 'eval_runtime': 1.5707, 'eval_samples_per_second': 26.739, 'eval_steps_per_second': 1.273, 'epoch': 24.0}


  return fn(*args, **kwargs)
 62%|██████▏   | 93/150 [22:03<12:57, 13.64s/it]
 62%|██████▏   | 93/150 [22:14<12:57, 13.64s/it]

{'eval_loss': 1.6948999166488647, 'eval_runtime': 1.5508, 'eval_samples_per_second': 27.083, 'eval_steps_per_second': 1.29, 'epoch': 24.8}


  return fn(*args, **kwargs)
 65%|██████▍   | 97/150 [23:02<12:31, 14.18s/it]
 65%|██████▍   | 97/150 [23:07<12:31, 14.18s/it]

{'eval_loss': 1.693650722503662, 'eval_runtime': 1.8207, 'eval_samples_per_second': 23.067, 'eval_steps_per_second': 1.098, 'epoch': 25.87}


  return fn(*args, **kwargs)
 67%|██████▋   | 100/150 [23:45<11:42, 14.05s/it]

{'loss': 1.873, 'learning_rate': 5.519332160124215e-06, 'epoch': 26.67}


 67%|██████▋   | 101/150 [23:58<11:16, 13.81s/it]
 67%|██████▋   | 101/150 [24:01<11:16, 13.81s/it]

{'eval_loss': 1.6925392150878906, 'eval_runtime': 1.5616, 'eval_samples_per_second': 26.896, 'eval_steps_per_second': 1.281, 'epoch': 26.93}


  return fn(*args, **kwargs)
 70%|███████   | 105/150 [24:53<10:05, 13.46s/it]
 70%|███████   | 105/150 [24:55<10:05, 13.46s/it]

{'eval_loss': 1.692589521408081, 'eval_runtime': 1.5345, 'eval_samples_per_second': 27.371, 'eval_steps_per_second': 1.303, 'epoch': 28.0}


  return fn(*args, **kwargs)
 72%|███████▏  | 108/150 [25:36<09:44, 13.93s/it]
 72%|███████▏  | 108/150 [25:47<09:44, 13.93s/it]

{'eval_loss': 1.692967414855957, 'eval_runtime': 1.5507, 'eval_samples_per_second': 27.084, 'eval_steps_per_second': 1.29, 'epoch': 28.8}


  return fn(*args, **kwargs)
 73%|███████▎  | 110/150 [26:08<09:51, 14.78s/it]

{'loss': 1.8477, 'learning_rate': 3.6667619695195287e-06, 'epoch': 29.33}


 75%|███████▍  | 112/150 [26:35<09:03, 14.31s/it]
 75%|███████▍  | 112/150 [26:41<09:03, 14.31s/it]

{'eval_loss': 1.6944692134857178, 'eval_runtime': 1.5757, 'eval_samples_per_second': 26.654, 'eval_steps_per_second': 1.269, 'epoch': 29.87}


  return fn(*args, **kwargs)
 77%|███████▋  | 116/150 [27:31<07:47, 13.76s/it]
 77%|███████▋  | 116/150 [27:34<07:47, 13.76s/it]

{'eval_loss': 1.6957935094833374, 'eval_runtime': 1.6097, 'eval_samples_per_second': 26.092, 'eval_steps_per_second': 1.242, 'epoch': 30.93}


  return fn(*args, **kwargs)
 80%|████████  | 120/150 [28:25<06:40, 13.36s/it]

{'loss': 1.8524, 'learning_rate': 2.1229202668228197e-06, 'epoch': 32.0}



 80%|████████  | 120/150 [28:27<06:40, 13.36s/it]

{'eval_loss': 1.6959683895111084, 'eval_runtime': 1.5518, 'eval_samples_per_second': 27.066, 'eval_steps_per_second': 1.289, 'epoch': 32.0}


  return fn(*args, **kwargs)
 82%|████████▏ | 123/150 [29:08<06:11, 13.75s/it]
 82%|████████▏ | 123/150 [29:20<06:11, 13.75s/it]

{'eval_loss': 1.6955112218856812, 'eval_runtime': 1.6069, 'eval_samples_per_second': 26.137, 'eval_steps_per_second': 1.245, 'epoch': 32.8}


  return fn(*args, **kwargs)
 85%|████████▍ | 127/150 [30:09<05:32, 14.45s/it]
 85%|████████▍ | 127/150 [30:13<05:32, 14.45s/it]

{'eval_loss': 1.6950746774673462, 'eval_runtime': 1.5285, 'eval_samples_per_second': 27.478, 'eval_steps_per_second': 1.308, 'epoch': 33.87}


  return fn(*args, **kwargs)
 87%|████████▋ | 130/150 [30:51<04:44, 14.25s/it]

{'loss': 1.8739, 'learning_rate': 9.630652236279626e-07, 'epoch': 34.67}


 87%|████████▋ | 131/150 [31:04<04:19, 13.68s/it]
 87%|████████▋ | 131/150 [31:07<04:19, 13.68s/it]

{'eval_loss': 1.6946264505386353, 'eval_runtime': 1.8569, 'eval_samples_per_second': 22.619, 'eval_steps_per_second': 1.077, 'epoch': 34.93}


  return fn(*args, **kwargs)
 90%|█████████ | 135/150 [31:58<03:20, 13.36s/it]
 90%|█████████ | 135/150 [32:00<03:20, 13.36s/it]

{'eval_loss': 1.6945058107376099, 'eval_runtime': 1.572, 'eval_samples_per_second': 26.718, 'eval_steps_per_second': 1.272, 'epoch': 36.0}


  return fn(*args, **kwargs)
 92%|█████████▏| 138/150 [32:41<02:45, 13.83s/it]
 92%|█████████▏| 138/150 [32:53<02:45, 13.83s/it]

{'eval_loss': 1.6944468021392822, 'eval_runtime': 1.8255, 'eval_samples_per_second': 23.008, 'eval_steps_per_second': 1.096, 'epoch': 36.8}


  return fn(*args, **kwargs)
 93%|█████████▎| 140/150 [33:14<02:28, 14.82s/it]

{'loss': 1.8417, 'learning_rate': 2.4373668447493225e-07, 'epoch': 37.33}


 95%|█████████▍| 142/150 [33:42<01:55, 14.38s/it]
 95%|█████████▍| 142/150 [33:47<01:55, 14.38s/it]

{'eval_loss': 1.6944329738616943, 'eval_runtime': 1.5759, 'eval_samples_per_second': 26.651, 'eval_steps_per_second': 1.269, 'epoch': 37.87}


  return fn(*args, **kwargs)
 97%|█████████▋| 146/150 [34:37<00:54, 13.64s/it]
 97%|█████████▋| 146/150 [34:40<00:54, 13.64s/it]

{'eval_loss': 1.6945277452468872, 'eval_runtime': 1.7325, 'eval_samples_per_second': 24.242, 'eval_steps_per_second': 1.154, 'epoch': 38.93}


  return fn(*args, **kwargs)
100%|██████████| 150/150 [35:32<00:00, 13.43s/it]

{'loss': 1.8467, 'learning_rate': 0.0, 'epoch': 40.0}



100%|██████████| 150/150 [35:33<00:00, 13.43s/it]

{'eval_loss': 1.6945019960403442, 'eval_runtime': 1.5686, 'eval_samples_per_second': 26.776, 'eval_steps_per_second': 1.275, 'epoch': 40.0}


100%|██████████| 150/150 [35:34<00:00, 14.23s/it]


{'train_runtime': 2134.0319, 'train_samples_per_second': 7.966, 'train_steps_per_second': 0.07, 'train_loss': 2.0329349136352537, 'epoch': 40.0}
Model saved to ./immigration_assistant_final
Testing model on examples...
Loading base model for comparison...


  return torch.load(checkpoint_file, map_location="cpu")


Loading fine-tuned model...


  adapters_weights = torch.load(



Generating responses from base and fine-tuned models...


--- Example 1 ---
Question: I received a Notice of Intent to Deny (NOID) my case from the government. What can I do?
Reference Answer: Many times the government improperly concludes that a case is deniable. Our experienced attorneys have successfully resolved cases in which the government intends to deny the case. While results may vary depending upon fact patterns and a case cannot always be resolved, a consultation with an attorney may turn up another avenue of relief.
Base Model Output: If you have not been notified that your application for asylum will be considered, contact our Asylum Assistance Center at 212-621-7000 or email [email protected] with all details regarding your claim and request information on how to file it online by May 25, 2018. Once we receive the requested documents and other supporting documentation in response to your NOID, we may begin processing your petition within 30 days of receipt, but if we fin

In [51]:
import os
import json
import glob
import re
import pandas as pd
import torch
from datasets import Dataset, DatasetDict
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    DataCollatorForLanguageModeling
)
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training, get_peft_model
from trl import SFTTrainer
import evaluate

# ==================== CONFIGURATION ====================
# Set paths
BASE_DIRECTORY = "/home/hailemicaelyimer/Desktop/immigration-assistant/frequently-asked-questions"
DATASET_PATH = "./immigration_qa_dataset_clean"
OUTPUT_DIR = "./immigration_assistant_model_final"
FINAL_MODEL_PATH = "./immigration_assistant_final"
RESULTS_CSV = "./model_comparison_results.csv"

# Open-access model that doesn't require authentication
MODEL_ID = "facebook/opt-1.3b"  # 1.3B parameters, open access

# Training parameters
EPOCHS = 100
BATCH_SIZE = 8
LEARNING_RATE = 2e-5
LORA_RANK = 32
LORA_ALPHA = 64

# ==================== HELPER FUNCTIONS ====================

def clean_text(text):
    """Clean text by removing question/answer prefixes and extra whitespace."""
    # Remove "Q." or "Q#." prefixes from questions
    text = re.sub(r'^Q\.?\s*\d*\.?\s*', '', text)
    # Remove "A." or "A#." prefixes from answers
    text = re.sub(r'^A\.?\s*\d*\.?\s*', '', text)
    # Remove extra whitespace
    text = re.sub(r'\s+', ' ', text).strip()
    return text

def load_and_clean_json_files(directory_path):
    """Load and clean all JSON files in the directory."""
    json_files = glob.glob(os.path.join(directory_path, "**/*.json"), recursive=True)
    
    all_data = []
    print(f"Found {len(json_files)} JSON files")
    
    for file_path in json_files:
        try:
            with open(file_path, 'r', encoding='utf-8') as file:
                data = json.load(file)
                
                for item in data:
                    question = item.get("question", "").strip()
                    answer = item.get("answer", "").strip()
                    
                    # Skip items with empty answers or questions
                    if not question or not answer:
                        continue
                    
                    # Clean the texts
                    question = clean_text(question)
                    answer = clean_text(answer)
                    
                    # Skip very short answers (likely not useful)
                    if len(answer) < 20:
                        continue
                    
                    all_data.append({
                        "Question": question,
                        "Answer": answer
                    })
        except Exception as e:
            print(f"Error processing file {file_path}: {e}")
    
    print(f"Loaded {len(all_data)} clean question-answer pairs")
    return all_data

def post_process_response(text):
    """Clean model outputs by removing repetitions and known artifacts."""
    # Remove irrelevant prefix text
    if "Question:" in text and "Answer:" in text:
        text = text.split("Answer:", 1)[1].strip()
    
    # Split by lines and remove duplicates while preserving order
    lines = text.split('\n')
    seen_texts = set()
    unique_lines = []
    
    for line in lines:
        line = line.strip()
        if not line:
            continue
            
        # Skip duplicate content
        if line in seen_texts:
            continue
            
        # Skip lines that are question-like
        if line.lower().startswith(("question:", "q:", "what is", "how do", "can i")):
            continue
            
        seen_texts.add(line)
        unique_lines.append(line)
    
    # Join unique lines
    processed_text = '\n'.join(unique_lines)
    
    # If we filtered too much, return the original without duplicates
    if len(processed_text) < 20 and len(text) > 20:
        lines = text.split('\n')
        seen_texts = set()
        unique_lines = []
        for line in lines:
            if line.strip() and line.strip() not in seen_texts:
                seen_texts.add(line.strip())
                unique_lines.append(line)
        processed_text = '\n'.join(unique_lines)
    
    return processed_text

# ==================== MAIN SCRIPT ====================

# Step 1: Load and prepare data
print("Loading and cleaning data...")
all_qa_data = load_and_clean_json_files(BASE_DIRECTORY)

# Convert to DataFrame
df = pd.DataFrame(all_qa_data)
print(f"Dataset shape: {df.shape}")
print("Sample data:")
print(df.head(2))

# Split into train, validation, and test sets (80%, 10%, 10%)
train_size = int(0.8 * len(df))
val_size = int(0.1 * len(df))

train_df = df[:train_size]
val_df = df[train_size:train_size+val_size]
test_df = df[train_size+val_size:]

print(f"Train size: {len(train_df)}, Validation size: {len(val_df)}, Test size: {len(test_df)}")

# Convert to Hugging Face datasets
train_dataset = Dataset.from_pandas(train_df)
val_dataset = Dataset.from_pandas(val_df)
test_dataset = Dataset.from_pandas(test_df)

# Combine into a dataset dictionary
dataset_dict = DatasetDict({
    'train': train_dataset,
    'validation': val_dataset,
    'test': test_dataset
})

# Save the clean dataset to disk
os.makedirs(DATASET_PATH, exist_ok=True)
dataset_dict.save_to_disk(DATASET_PATH)
print(f"Dataset saved to {DATASET_PATH}")

# Step 2: Load Model and Tokenizer
# Define quantization config for 4-bit precision
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=False,
)

# Load tokenizer
print(f"Loading tokenizer for {MODEL_ID}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# Load model with quantization config
print("Loading model...")
device_map = {"": 0}  # Use GPU 0
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, 
    quantization_config=bnb_config, 
    use_cache=False,
    device_map=device_map
)

# Step 3: Define preprocessing function for clean instruction format
def preprocess_function(examples):
    # Use a clear instruction format without complex templates
    formatted_prompts = [
        f"### Instruction: You are an immigration assistant. Provide accurate information about this question: {q}\n\n### Response:" 
        for q in examples["Question"]
    ]
    
    return {
        "input_ids": tokenizer(
            formatted_prompts,
            truncation=True, 
            max_length=512, 
            padding="max_length"
        )["input_ids"],
        "labels": tokenizer(
            examples["Answer"], 
            truncation=True, 
            max_length=512, 
            padding="max_length"
        )["input_ids"],
        "inputs_text": [f"{prompt} {answer}" for prompt, answer in zip(formatted_prompts, examples["Answer"])],
    }

# Apply preprocessing to datasets
print("Preprocessing datasets...")
processed_train_dataset = dataset_dict['train'].map(preprocess_function, batched=True)
processed_val_dataset = dataset_dict['validation'].map(preprocess_function, batched=True)
processed_test_dataset = dataset_dict['test'].map(preprocess_function, batched=True)

print(f"Processed train dataset size: {len(processed_train_dataset)}")
print(f"Processed validation dataset size: {len(processed_val_dataset)}")
print(f"Processed test dataset size: {len(processed_test_dataset)}")

# Step 4: Configure LoRA for efficient fine-tuning
peft_config = LoraConfig(
    lora_alpha=LORA_ALPHA,
    lora_dropout=0.05,  # Reduced dropout for better learning
    r=LORA_RANK,
    bias="none",
    task_type="CAUSAL_LM",
    # Target projection layers in OPT model
    target_modules=["q_proj", "k_proj", "v_proj", "out_proj"].21
)

# Prepare model for kbit training
print("Preparing model for training...")
model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

# Step 5: Define training arguments
os.makedirs(OUTPUT_DIR, exist_ok=True)

training_arguments = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    optim="paged_adamw_32bit",
    logging_steps=10,
    save_strategy="epoch",
    evaluation_strategy="epoch",
    learning_rate=LEARNING_RATE,
    weight_decay=0.01,
    max_grad_norm=0.3,
    warmup_ratio=0.05,
    group_by_length=True,
    lr_scheduler_type="cosine",
    fp16=False,
    bf16=False,
    report_to="none",
    # Add the following to prevent repetition during training
    remove_unused_columns=False,
    label_names=["labels"],
)

# Data collator for language model training
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False
)

# Step 6: Create and train the model
trainer = SFTTrainer(
    model=model,
    train_dataset=processed_train_dataset,
    eval_dataset=processed_val_dataset,
    dataset_text_field="inputs_text",
    max_seq_length=512,
    tokenizer=tokenizer,
    args=training_arguments,
    data_collator=data_collator,
    packing=False,
)

# Start training
print("Starting training...")
trainer.train()

# Step 7: Save the trained model
os.makedirs(FINAL_MODEL_PATH, exist_ok=True)
trainer.model.save_pretrained(FINAL_MODEL_PATH)
tokenizer.save_pretrained(FINAL_MODEL_PATH)
print(f"Model saved to {FINAL_MODEL_PATH}")

# Step 8: Test the model on a few examples
print("Testing model on examples...")

# Load rouge for evaluation
rouge = evaluate.load('rouge')

# Reload base model for comparison
print("Loading base model for comparison...")
base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    device_map=device_map
)

# Load fine-tuned model (PEFT)
print("Loading fine-tuned model...")
peft_model = PeftModel.from_pretrained(
    base_model,
    FINAL_MODEL_PATH,
    device_map=device_map
)

# Test on examples from test set
test_questions = test_df['Question'][:10].tolist()  # Test on 10 examples
test_answers = test_df['Answer'][:10].tolist()

base_model_outputs = []
peft_model_outputs = []

print("\nGenerating responses from base and fine-tuned models...")
for question in test_questions:
    # Format prompt for the model
    prompt = f"### Instruction: You are an immigration assistant. Provide accurate information about this question: {question}\n\n### Response:"
    
    # Base model generation
    input_ids = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).input_ids.to("cuda")
    base_outputs = base_model.generate(
        input_ids=input_ids, 
        max_length=len(input_ids[0]) + 250, 
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        repetition_penalty=1.3,  # Increased repetition penalty
        no_repeat_ngram_size=3   # Prevent repeating 3-grams
    )
    base_text = tokenizer.decode(base_outputs[0], skip_special_tokens=True)
    
    # Clean the base model output
    base_text = base_text.replace(prompt, "").strip()
    base_text = post_process_response(base_text)
    base_model_outputs.append(base_text)
    
    # Fine-tuned model generation
    ft_outputs = peft_model.generate(
        input_ids=input_ids, 
        max_length=len(input_ids[0]) + 250, 
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        repetition_penalty=1.3,
        no_repeat_ngram_size=3
    )
    ft_text = tokenizer.decode(ft_outputs[0], skip_special_tokens=True)
    
    # Clean the fine-tuned model output
    ft_text = ft_text.replace(prompt, "").strip()
    ft_text = post_process_response(ft_text)
    peft_model_outputs.append(ft_text)

# Print results for a few examples
for i, (question, answer, base_output, peft_output) in enumerate(zip(test_questions[:3], test_answers[:3], base_model_outputs[:3], peft_model_outputs[:3])):
    print(f"\n\n--- Example {i+1} ---")
    print(f"Question: {question}")
    print(f"Reference Answer: {answer}")
    print(f"Base Model Output: {base_output}")
    print(f"Fine-tuned Model Output: {peft_output}")

# Calculate ROUGE scores
base_rouge_results = rouge.compute(
    predictions=base_model_outputs,
    references=test_answers[:len(base_model_outputs)],
    use_stemmer=True
)

peft_rouge_results = rouge.compute(
    predictions=peft_model_outputs,
    references=test_answers[:len(peft_model_outputs)],
    use_stemmer=True
)

print("\n--- ROUGE Scores ---")
print("Base Model:")
print(base_rouge_results)
print("\nFine-tuned Model:")
print(peft_rouge_results)

# Save the generated responses for manual inspection
results_df = pd.DataFrame({
    "Question": test_questions,
    "Reference_Answer": test_answers[:len(test_questions)],
    "Base_Model_Output": base_model_outputs,
    "Fine_Tuned_Output": peft_model_outputs
})
results_df.to_csv(RESULTS_CSV, index=False)
print(f"\nSaved comparison results to {RESULTS_CSV}")

# Create a simple inference function to test the model interactively
def query_model(question, model=peft_model):
    prompt = f"### Instruction: You are an immigration assistant. Provide accurate information about this question: {question}\n\n### Response:"
    input_ids = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).input_ids.to("cuda")
    
    outputs = model.generate(
        input_ids=input_ids, 
        max_length=len(input_ids[0]) + 300, 
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        repetition_penalty=1.3,
        no_repeat_ngram_size=3
    )
    
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Clean the response
    response = response.replace(prompt, "").strip()
    response = post_process_response(response)
    return response

print("\nTraining and evaluation complete!")
print("\nYou can now use the query_model() function to test your model interactively.")
print("Example: response = query_model('What is the processing time for a green card application?')")

Loading and cleaning data...
Found 28 JSON files
Loaded 425 clean question-answer pairs
Dataset shape: (425, 2)
Sample data:
                                            Question  \
0  fter one year, how do I demonstrate that the n...   
1  Where can I find information about vaccination...   

                                              Answer  
0  International Entrepreneur RuleUnder the Inter...  
1  CDC publishes information about vaccinations i...  
Train size: 340, Validation size: 42, Test size: 43


Saving the dataset (1/1 shards): 100%|██████████| 340/340 [00:00<00:00, 82317.21 examples/s] 
Saving the dataset (1/1 shards): 100%|██████████| 42/42 [00:00<00:00, 5430.19 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 43/43 [00:00<00:00, 8869.19 examples/s] 

Dataset saved to ./immigration_qa_dataset_clean
Loading tokenizer for facebook/opt-1.3b...





Loading model...


  return torch.load(checkpoint_file, map_location="cpu")


Preprocessing datasets...


Map: 100%|██████████| 340/340 [00:00<00:00, 1770.37 examples/s]
Map: 100%|██████████| 42/42 [00:00<00:00, 2396.94 examples/s]
Map: 100%|██████████| 43/43 [00:00<00:00, 570.89 examples/s]


Processed train dataset size: 340
Processed validation dataset size: 42
Processed test dataset size: 43
Preparing model for training...
trainable params: 12,582,912 || all params: 724,361,216 || trainable%: 1.7371045994820353


Map: 100%|██████████| 340/340 [00:00<00:00, 2279.22 examples/s]
Map: 100%|██████████| 42/42 [00:00<00:00, 5799.15 examples/s]


Starting training...


  0%|          | 0/200 [00:00<?, ?it/s]You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
  return fn(*args, **kwargs)
  1%|          | 2/200 [00:34<57:24, 17.40s/it]
  1%|          | 2/200 [00:51<57:24, 17.40s/it]

{'eval_loss': 2.847487688064575, 'eval_runtime': 1.8094, 'eval_samples_per_second': 23.212, 'eval_steps_per_second': 1.105, 'epoch': 0.73}


  return fn(*args, **kwargs)
  2%|▎         | 5/200 [01:36<1:02:20, 19.18s/it]
  2%|▎         | 5/200 [01:43<1:02:20, 19.18s/it]

{'eval_loss': 2.803516149520874, 'eval_runtime': 1.8157, 'eval_samples_per_second': 23.132, 'eval_steps_per_second': 1.102, 'epoch': 1.82}


  return fn(*args, **kwargs)
  4%|▍         | 8/200 [02:31<58:42, 18.35s/it]  
  4%|▍         | 8/200 [02:34<58:42, 18.35s/it]

{'eval_loss': 2.725045680999756, 'eval_runtime': 1.8514, 'eval_samples_per_second': 22.685, 'eval_steps_per_second': 1.08, 'epoch': 2.91}


  return fn(*args, **kwargs)
  5%|▌         | 10/200 [03:08<58:09, 18.37s/it]

{'loss': 2.7629, 'learning_rate': 2e-05, 'epoch': 3.64}


  6%|▌         | 11/200 [03:24<55:43, 17.69s/it]
  6%|▌         | 11/200 [03:26<55:43, 17.69s/it]

{'eval_loss': 2.621805191040039, 'eval_runtime': 1.6517, 'eval_samples_per_second': 25.429, 'eval_steps_per_second': 1.211, 'epoch': 4.0}


  return fn(*args, **kwargs)
  6%|▋         | 13/200 [04:02<56:40, 18.19s/it]
  6%|▋         | 13/200 [04:17<56:40, 18.19s/it]

{'eval_loss': 2.5585107803344727, 'eval_runtime': 1.6885, 'eval_samples_per_second': 24.874, 'eval_steps_per_second': 1.184, 'epoch': 4.73}


  return fn(*args, **kwargs)
  8%|▊         | 16/200 [05:03<58:12, 18.98s/it]  
  8%|▊         | 16/200 [05:09<58:12, 18.98s/it]

{'eval_loss': 2.4672675132751465, 'eval_runtime': 2.0909, 'eval_samples_per_second': 20.087, 'eval_steps_per_second': 0.957, 'epoch': 5.82}


  return fn(*args, **kwargs)
 10%|▉         | 19/200 [05:57<54:56, 18.21s/it]
 10%|▉         | 19/200 [06:01<54:56, 18.21s/it]

{'eval_loss': 2.370513916015625, 'eval_runtime': 1.881, 'eval_samples_per_second': 22.328, 'eval_steps_per_second': 1.063, 'epoch': 6.91}


  return fn(*args, **kwargs)
 10%|█         | 20/200 [06:17<55:51, 18.62s/it]

{'loss': 2.5794, 'learning_rate': 1.9863613034027224e-05, 'epoch': 7.27}


 11%|█         | 22/200 [06:50<52:16, 17.62s/it]
 11%|█         | 22/200 [06:52<52:16, 17.62s/it]

{'eval_loss': 2.2617435455322266, 'eval_runtime': 1.8213, 'eval_samples_per_second': 23.061, 'eval_steps_per_second': 1.098, 'epoch': 8.0}


  return fn(*args, **kwargs)
 12%|█▏        | 24/200 [07:29<53:37, 18.28s/it]
 12%|█▏        | 24/200 [07:44<53:37, 18.28s/it]

{'eval_loss': 2.1968066692352295, 'eval_runtime': 1.7744, 'eval_samples_per_second': 23.671, 'eval_steps_per_second': 1.127, 'epoch': 8.73}


  return fn(*args, **kwargs)
 14%|█▎        | 27/200 [08:30<55:04, 19.10s/it]
 14%|█▎        | 27/200 [08:35<55:04, 19.10s/it]

{'eval_loss': 2.0851800441741943, 'eval_runtime': 1.6468, 'eval_samples_per_second': 25.505, 'eval_steps_per_second': 1.215, 'epoch': 9.82}


  return fn(*args, **kwargs)
 15%|█▌        | 30/200 [09:24<51:56, 18.33s/it]

{'loss': 2.348, 'learning_rate': 1.9458172417006347e-05, 'epoch': 10.91}



 15%|█▌        | 30/200 [09:27<51:56, 18.33s/it]

{'eval_loss': 1.9774829149246216, 'eval_runtime': 1.8293, 'eval_samples_per_second': 22.96, 'eval_steps_per_second': 1.093, 'epoch': 10.91}


  return fn(*args, **kwargs)
 16%|█▋        | 33/200 [10:17<49:11, 17.67s/it]
 16%|█▋        | 33/200 [10:20<49:11, 17.67s/it]

{'eval_loss': 1.8714172840118408, 'eval_runtime': 2.0183, 'eval_samples_per_second': 20.809, 'eval_steps_per_second': 0.991, 'epoch': 12.0}


  return fn(*args, **kwargs)
 18%|█▊        | 35/200 [10:56<50:21, 18.31s/it]
 18%|█▊        | 35/200 [11:12<50:21, 18.31s/it]

{'eval_loss': 1.822710633277893, 'eval_runtime': 1.8142, 'eval_samples_per_second': 23.151, 'eval_steps_per_second': 1.102, 'epoch': 12.73}


  return fn(*args, **kwargs)
 19%|█▉        | 38/200 [11:57<51:48, 19.19s/it]
 19%|█▉        | 38/200 [12:03<51:48, 19.19s/it]

{'eval_loss': 1.7714369297027588, 'eval_runtime': 1.7912, 'eval_samples_per_second': 23.447, 'eval_steps_per_second': 1.117, 'epoch': 13.82}


  return fn(*args, **kwargs)
 20%|██        | 40/200 [12:34<50:18, 18.87s/it]

{'loss': 2.1278, 'learning_rate': 1.879473751206489e-05, 'epoch': 14.55}


 20%|██        | 41/200 [12:52<48:40, 18.37s/it]
 20%|██        | 41/200 [12:55<48:40, 18.37s/it]

{'eval_loss': 1.7284619808197021, 'eval_runtime': 1.821, 'eval_samples_per_second': 23.064, 'eval_steps_per_second': 1.098, 'epoch': 14.91}


  return fn(*args, **kwargs)
 22%|██▏       | 44/200 [13:45<45:53, 17.65s/it]
 22%|██▏       | 44/200 [13:47<45:53, 17.65s/it]

{'eval_loss': 1.6999809741973877, 'eval_runtime': 1.8031, 'eval_samples_per_second': 23.293, 'eval_steps_per_second': 1.109, 'epoch': 16.0}


  return fn(*args, **kwargs)
 23%|██▎       | 46/200 [14:23<46:55, 18.28s/it]
 23%|██▎       | 46/200 [14:38<46:55, 18.28s/it]

{'eval_loss': 1.6949964761734009, 'eval_runtime': 1.8337, 'eval_samples_per_second': 22.904, 'eval_steps_per_second': 1.091, 'epoch': 16.73}


  return fn(*args, **kwargs)
 24%|██▍       | 49/200 [15:24<48:12, 19.15s/it]
 24%|██▍       | 49/200 [15:31<48:12, 19.15s/it]

{'eval_loss': 1.6941070556640625, 'eval_runtime': 1.9789, 'eval_samples_per_second': 21.224, 'eval_steps_per_second': 1.011, 'epoch': 17.82}


  return fn(*args, **kwargs)
 25%|██▌       | 50/200 [15:44<48:31, 19.41s/it]

{'loss': 1.9868, 'learning_rate': 1.789140509396394e-05, 'epoch': 18.18}


 26%|██▌       | 52/200 [16:19<45:11, 18.32s/it]
 26%|██▌       | 52/200 [16:22<45:11, 18.32s/it]

{'eval_loss': 1.7021595239639282, 'eval_runtime': 1.8205, 'eval_samples_per_second': 23.07, 'eval_steps_per_second': 1.099, 'epoch': 18.91}


  return fn(*args, **kwargs)
 28%|██▊       | 55/200 [17:12<42:54, 17.76s/it]
 28%|██▊       | 55/200 [17:14<42:54, 17.76s/it]

{'eval_loss': 1.7045756578445435, 'eval_runtime': 1.8095, 'eval_samples_per_second': 23.21, 'eval_steps_per_second': 1.105, 'epoch': 20.0}


  return fn(*args, **kwargs)
 28%|██▊       | 57/200 [17:51<43:48, 18.38s/it]
 28%|██▊       | 57/200 [18:06<43:48, 18.38s/it]

{'eval_loss': 1.698886513710022, 'eval_runtime': 1.8343, 'eval_samples_per_second': 22.898, 'eval_steps_per_second': 1.09, 'epoch': 20.73}


  return fn(*args, **kwargs)
 30%|███       | 60/200 [18:51<44:20, 19.00s/it]

{'loss': 1.9386, 'learning_rate': 1.6772815716257414e-05, 'epoch': 21.82}



 30%|███       | 60/200 [18:58<44:20, 19.00s/it]

{'eval_loss': 1.6958081722259521, 'eval_runtime': 1.832, 'eval_samples_per_second': 22.926, 'eval_steps_per_second': 1.092, 'epoch': 21.82}


  return fn(*args, **kwargs)
 32%|███▏      | 63/200 [19:46<41:27, 18.16s/it]
 32%|███▏      | 63/200 [19:49<41:27, 18.16s/it]

{'eval_loss': 1.6989738941192627, 'eval_runtime': 2.0615, 'eval_samples_per_second': 20.373, 'eval_steps_per_second': 0.97, 'epoch': 22.91}


  return fn(*args, **kwargs)
 33%|███▎      | 66/200 [20:39<39:28, 17.68s/it]
 33%|███▎      | 66/200 [20:41<39:28, 17.68s/it]

{'eval_loss': 1.697507381439209, 'eval_runtime': 1.681, 'eval_samples_per_second': 24.986, 'eval_steps_per_second': 1.19, 'epoch': 24.0}


  return fn(*args, **kwargs)
 34%|███▍      | 68/200 [21:17<39:45, 18.07s/it]
 34%|███▍      | 68/200 [21:33<39:45, 18.07s/it]

{'eval_loss': 1.6943410634994507, 'eval_runtime': 1.8301, 'eval_samples_per_second': 22.95, 'eval_steps_per_second': 1.093, 'epoch': 24.73}


  return fn(*args, **kwargs)
 35%|███▌      | 70/200 [22:00<42:07, 19.44s/it]

{'loss': 1.9046, 'learning_rate': 1.5469481581224274e-05, 'epoch': 25.45}


 36%|███▌      | 71/200 [22:18<41:19, 19.22s/it]
 36%|███▌      | 71/200 [22:25<41:19, 19.22s/it]

{'eval_loss': 1.6931068897247314, 'eval_runtime': 1.796, 'eval_samples_per_second': 23.386, 'eval_steps_per_second': 1.114, 'epoch': 25.82}


  return fn(*args, **kwargs)
 37%|███▋      | 74/200 [23:13<38:31, 18.34s/it]
 37%|███▋      | 74/200 [23:16<38:31, 18.34s/it]

{'eval_loss': 1.6958011388778687, 'eval_runtime': 1.8568, 'eval_samples_per_second': 22.619, 'eval_steps_per_second': 1.077, 'epoch': 26.91}


  return fn(*args, **kwargs)
 38%|███▊      | 77/200 [24:06<36:15, 17.69s/it]
 38%|███▊      | 77/200 [24:08<36:15, 17.69s/it]

{'eval_loss': 1.7006253004074097, 'eval_runtime': 1.835, 'eval_samples_per_second': 22.888, 'eval_steps_per_second': 1.09, 'epoch': 28.0}


  return fn(*args, **kwargs)
 40%|███▉      | 79/200 [24:44<36:51, 18.27s/it]
 40%|███▉      | 79/200 [25:00<36:51, 18.27s/it]

{'eval_loss': 1.7014676332473755, 'eval_runtime': 1.9267, 'eval_samples_per_second': 21.799, 'eval_steps_per_second': 1.038, 'epoch': 28.73}


  return fn(*args, **kwargs)
 40%|████      | 80/200 [25:10<40:44, 20.37s/it]

{'loss': 1.8825, 'learning_rate': 1.4016954246529697e-05, 'epoch': 29.09}


 41%|████      | 82/200 [25:45<37:27, 19.05s/it]
 41%|████      | 82/200 [25:52<37:27, 19.05s/it]

{'eval_loss': 1.6990200281143188, 'eval_runtime': 1.8196, 'eval_samples_per_second': 23.083, 'eval_steps_per_second': 1.099, 'epoch': 29.82}


  return fn(*args, **kwargs)
 42%|████▎     | 85/200 [26:40<35:12, 18.37s/it]
 42%|████▎     | 85/200 [26:43<35:12, 18.37s/it]

{'eval_loss': 1.6996970176696777, 'eval_runtime': 1.6385, 'eval_samples_per_second': 25.634, 'eval_steps_per_second': 1.221, 'epoch': 30.91}


  return fn(*args, **kwargs)
 44%|████▍     | 88/200 [27:33<32:50, 17.60s/it]
 44%|████▍     | 88/200 [27:35<32:50, 17.60s/it]

{'eval_loss': 1.7004210948944092, 'eval_runtime': 1.8033, 'eval_samples_per_second': 23.29, 'eval_steps_per_second': 1.109, 'epoch': 32.0}


  return fn(*args, **kwargs)
 45%|████▌     | 90/200 [28:10<33:05, 18.05s/it]

{'loss': 1.8613, 'learning_rate': 1.2454854871407993e-05, 'epoch': 32.73}



 45%|████▌     | 90/200 [28:26<33:05, 18.05s/it]

{'eval_loss': 1.7016347646713257, 'eval_runtime': 1.8096, 'eval_samples_per_second': 23.21, 'eval_steps_per_second': 1.105, 'epoch': 32.73}


  return fn(*args, **kwargs)
 46%|████▋     | 93/200 [29:12<33:58, 19.05s/it]
 46%|████▋     | 93/200 [29:18<33:58, 19.05s/it]

{'eval_loss': 1.7014573812484741, 'eval_runtime': 1.8299, 'eval_samples_per_second': 22.951, 'eval_steps_per_second': 1.093, 'epoch': 33.82}


  return fn(*args, **kwargs)
 48%|████▊     | 96/200 [30:06<31:41, 18.28s/it]
 48%|████▊     | 96/200 [30:10<31:41, 18.28s/it]

{'eval_loss': 1.702175498008728, 'eval_runtime': 2.003, 'eval_samples_per_second': 20.969, 'eval_steps_per_second': 0.999, 'epoch': 34.91}


  return fn(*args, **kwargs)
 50%|████▉     | 99/200 [31:00<29:47, 17.70s/it]
 50%|████▉     | 99/200 [31:01<29:47, 17.70s/it]

{'eval_loss': 1.7037568092346191, 'eval_runtime': 1.6814, 'eval_samples_per_second': 24.979, 'eval_steps_per_second': 1.189, 'epoch': 36.0}


  return fn(*args, **kwargs)
 50%|█████     | 100/200 [31:19<30:22, 18.22s/it]

{'loss': 1.8457, 'learning_rate': 1.0825793454723325e-05, 'epoch': 36.36}


 50%|█████     | 101/200 [31:37<29:55, 18.13s/it]
 50%|█████     | 101/200 [31:53<29:55, 18.13s/it]

{'eval_loss': 1.7039384841918945, 'eval_runtime': 1.7989, 'eval_samples_per_second': 23.347, 'eval_steps_per_second': 1.112, 'epoch': 36.73}


  return fn(*args, **kwargs)
 52%|█████▏    | 104/200 [32:38<30:28, 19.05s/it]
 52%|█████▏    | 104/200 [32:45<30:28, 19.05s/it]

{'eval_loss': 1.7037062644958496, 'eval_runtime': 1.789, 'eval_samples_per_second': 23.477, 'eval_steps_per_second': 1.118, 'epoch': 37.82}


  return fn(*args, **kwargs)
 54%|█████▎    | 107/200 [33:33<28:22, 18.31s/it]
 54%|█████▎    | 107/200 [33:37<28:22, 18.31s/it]

{'eval_loss': 1.7044517993927002, 'eval_runtime': 1.7454, 'eval_samples_per_second': 24.064, 'eval_steps_per_second': 1.146, 'epoch': 38.91}


  return fn(*args, **kwargs)
 55%|█████▌    | 110/200 [34:26<26:25, 17.62s/it]

{'loss': 1.8333, 'learning_rate': 9.174206545276678e-06, 'epoch': 40.0}



 55%|█████▌    | 110/200 [34:28<26:25, 17.62s/it]

{'eval_loss': 1.7064180374145508, 'eval_runtime': 1.8356, 'eval_samples_per_second': 22.881, 'eval_steps_per_second': 1.09, 'epoch': 40.0}


  return fn(*args, **kwargs)
 56%|█████▌    | 112/200 [35:04<26:41, 18.20s/it]
 56%|█████▌    | 112/200 [35:20<26:41, 18.20s/it]

{'eval_loss': 1.7077680826187134, 'eval_runtime': 1.9728, 'eval_samples_per_second': 21.289, 'eval_steps_per_second': 1.014, 'epoch': 40.73}


  return fn(*args, **kwargs)
 57%|█████▊    | 115/200 [36:06<27:04, 19.11s/it]
 57%|█████▊    | 115/200 [36:12<27:04, 19.11s/it]

{'eval_loss': 1.708708643913269, 'eval_runtime': 1.8184, 'eval_samples_per_second': 23.097, 'eval_steps_per_second': 1.1, 'epoch': 41.82}


  return fn(*args, **kwargs)
 59%|█████▉    | 118/200 [37:00<25:02, 18.33s/it]
 59%|█████▉    | 118/200 [37:04<25:02, 18.33s/it]

{'eval_loss': 1.7073900699615479, 'eval_runtime': 1.7608, 'eval_samples_per_second': 23.852, 'eval_steps_per_second': 1.136, 'epoch': 42.91}


  return fn(*args, **kwargs)
 60%|██████    | 120/200 [37:38<24:39, 18.50s/it]

{'loss': 1.8337, 'learning_rate': 7.545145128592009e-06, 'epoch': 43.64}


 60%|██████    | 121/200 [37:54<23:16, 17.68s/it]
 60%|██████    | 121/200 [37:55<23:16, 17.68s/it]

{'eval_loss': 1.7064483165740967, 'eval_runtime': 1.8248, 'eval_samples_per_second': 23.016, 'eval_steps_per_second': 1.096, 'epoch': 44.0}


  return fn(*args, **kwargs)
 62%|██████▏   | 123/200 [38:31<23:27, 18.28s/it]
 62%|██████▏   | 123/200 [38:47<23:27, 18.28s/it]

{'eval_loss': 1.7063566446304321, 'eval_runtime': 1.789, 'eval_samples_per_second': 23.477, 'eval_steps_per_second': 1.118, 'epoch': 44.73}


  return fn(*args, **kwargs)
 63%|██████▎   | 126/200 [39:32<23:33, 19.10s/it]
 63%|██████▎   | 126/200 [39:39<23:33, 19.10s/it]

{'eval_loss': 1.7066733837127686, 'eval_runtime': 1.8671, 'eval_samples_per_second': 22.494, 'eval_steps_per_second': 1.071, 'epoch': 45.82}


  return fn(*args, **kwargs)
 64%|██████▍   | 129/200 [40:27<21:41, 18.33s/it]
 64%|██████▍   | 129/200 [40:31<21:41, 18.33s/it]

{'eval_loss': 1.7075181007385254, 'eval_runtime': 2.0817, 'eval_samples_per_second': 20.176, 'eval_steps_per_second': 0.961, 'epoch': 46.91}


  return fn(*args, **kwargs)
 65%|██████▌   | 130/200 [40:46<21:42, 18.61s/it]

{'loss': 1.8194, 'learning_rate': 5.983045753470308e-06, 'epoch': 47.27}


 66%|██████▌   | 132/200 [41:21<20:07, 17.76s/it]
 66%|██████▌   | 132/200 [41:23<20:07, 17.76s/it]

{'eval_loss': 1.7084261178970337, 'eval_runtime': 1.8472, 'eval_samples_per_second': 22.738, 'eval_steps_per_second': 1.083, 'epoch': 48.0}


  return fn(*args, **kwargs)
 67%|██████▋   | 134/200 [41:59<20:07, 18.30s/it]
 67%|██████▋   | 134/200 [42:14<20:07, 18.30s/it]

{'eval_loss': 1.7094347476959229, 'eval_runtime': 1.8103, 'eval_samples_per_second': 23.2, 'eval_steps_per_second': 1.105, 'epoch': 48.73}


  return fn(*args, **kwargs)
 68%|██████▊   | 137/200 [43:00<20:05, 19.13s/it]
 68%|██████▊   | 137/200 [43:06<20:05, 19.13s/it]

{'eval_loss': 1.709874153137207, 'eval_runtime': 1.8356, 'eval_samples_per_second': 22.881, 'eval_steps_per_second': 1.09, 'epoch': 49.82}


  return fn(*args, **kwargs)
 70%|███████   | 140/200 [43:55<18:31, 18.52s/it]

{'loss': 1.8099, 'learning_rate': 4.530518418775734e-06, 'epoch': 50.91}



 70%|███████   | 140/200 [43:58<18:31, 18.52s/it]

{'eval_loss': 1.7104532718658447, 'eval_runtime': 1.814, 'eval_samples_per_second': 23.153, 'eval_steps_per_second': 1.103, 'epoch': 50.91}


  return fn(*args, **kwargs)
 72%|███████▏  | 143/200 [44:49<16:53, 17.77s/it]
 72%|███████▏  | 143/200 [44:50<16:53, 17.77s/it]

{'eval_loss': 1.7105985879898071, 'eval_runtime': 1.692, 'eval_samples_per_second': 24.822, 'eval_steps_per_second': 1.182, 'epoch': 52.0}


  return fn(*args, **kwargs)
 72%|███████▎  | 145/200 [45:26<16:40, 18.18s/it]
 72%|███████▎  | 145/200 [45:42<16:40, 18.18s/it]

{'eval_loss': 1.7105295658111572, 'eval_runtime': 2.0492, 'eval_samples_per_second': 20.496, 'eval_steps_per_second': 0.976, 'epoch': 52.73}


  return fn(*args, **kwargs)
 74%|███████▍  | 148/200 [46:27<16:26, 18.97s/it]
 74%|███████▍  | 148/200 [46:33<16:26, 18.97s/it]

{'eval_loss': 1.7105573415756226, 'eval_runtime': 1.7892, 'eval_samples_per_second': 23.474, 'eval_steps_per_second': 1.118, 'epoch': 53.82}


  return fn(*args, **kwargs)
 75%|███████▌  | 150/200 [47:05<15:51, 19.02s/it]

{'loss': 1.8121, 'learning_rate': 3.2271842837425917e-06, 'epoch': 54.55}


 76%|███████▌  | 151/200 [47:22<14:58, 18.34s/it]
 76%|███████▌  | 151/200 [47:26<14:58, 18.34s/it]

{'eval_loss': 1.7104713916778564, 'eval_runtime': 1.8058, 'eval_samples_per_second': 23.258, 'eval_steps_per_second': 1.108, 'epoch': 54.91}


  return fn(*args, **kwargs)
 77%|███████▋  | 154/200 [48:15<13:28, 17.58s/it]
 77%|███████▋  | 154/200 [48:17<13:28, 17.58s/it]

{'eval_loss': 1.7108784914016724, 'eval_runtime': 1.8397, 'eval_samples_per_second': 22.83, 'eval_steps_per_second': 1.087, 'epoch': 56.0}


  return fn(*args, **kwargs)
 78%|███████▊  | 156/200 [48:53<13:19, 18.17s/it]
 78%|███████▊  | 156/200 [49:08<13:19, 18.17s/it]

{'eval_loss': 1.711111307144165, 'eval_runtime': 1.6997, 'eval_samples_per_second': 24.711, 'eval_steps_per_second': 1.177, 'epoch': 56.73}


  return fn(*args, **kwargs)
 80%|███████▉  | 159/200 [49:54<13:00, 19.03s/it]
 80%|███████▉  | 159/200 [50:00<13:00, 19.03s/it]

{'eval_loss': 1.7113313674926758, 'eval_runtime': 1.7017, 'eval_samples_per_second': 24.681, 'eval_steps_per_second': 1.175, 'epoch': 57.82}


  return fn(*args, **kwargs)
 80%|████████  | 160/200 [50:14<12:53, 19.34s/it]

{'loss': 1.7976, 'learning_rate': 2.1085949060360654e-06, 'epoch': 58.18}


 81%|████████  | 162/200 [50:49<11:35, 18.32s/it]
 81%|████████  | 162/200 [50:52<11:35, 18.32s/it]

{'eval_loss': 1.7113996744155884, 'eval_runtime': 1.9334, 'eval_samples_per_second': 21.724, 'eval_steps_per_second': 1.034, 'epoch': 58.91}


  return fn(*args, **kwargs)
 82%|████████▎ | 165/200 [51:42<10:21, 17.75s/it]
 82%|████████▎ | 165/200 [51:44<10:21, 17.75s/it]

{'eval_loss': 1.7113734483718872, 'eval_runtime': 1.8222, 'eval_samples_per_second': 23.048, 'eval_steps_per_second': 1.098, 'epoch': 60.0}


  return fn(*args, **kwargs)
 84%|████████▎ | 167/200 [52:20<09:59, 18.18s/it]
 84%|████████▎ | 167/200 [52:36<09:59, 18.18s/it]

{'eval_loss': 1.7115631103515625, 'eval_runtime': 1.7625, 'eval_samples_per_second': 23.829, 'eval_steps_per_second': 1.135, 'epoch': 60.73}


  return fn(*args, **kwargs)
 85%|████████▌ | 170/200 [53:22<09:38, 19.29s/it]

{'loss': 1.7975, 'learning_rate': 1.2052624879351105e-06, 'epoch': 61.82}



 85%|████████▌ | 170/200 [53:28<09:38, 19.29s/it]

{'eval_loss': 1.7114768028259277, 'eval_runtime': 1.8068, 'eval_samples_per_second': 23.245, 'eval_steps_per_second': 1.107, 'epoch': 61.82}


  return fn(*args, **kwargs)
 86%|████████▋ | 173/200 [54:17<08:17, 18.41s/it]
 86%|████████▋ | 173/200 [54:20<08:17, 18.41s/it]

{'eval_loss': 1.7115662097930908, 'eval_runtime': 1.8128, 'eval_samples_per_second': 23.168, 'eval_steps_per_second': 1.103, 'epoch': 62.91}


  return fn(*args, **kwargs)
 88%|████████▊ | 176/200 [55:10<07:04, 17.69s/it]
 88%|████████▊ | 176/200 [55:12<07:04, 17.69s/it]

{'eval_loss': 1.7114976644515991, 'eval_runtime': 1.8216, 'eval_samples_per_second': 23.057, 'eval_steps_per_second': 1.098, 'epoch': 64.0}


  return fn(*args, **kwargs)
 89%|████████▉ | 178/200 [55:48<06:39, 18.17s/it]
 89%|████████▉ | 178/200 [56:04<06:39, 18.17s/it]

{'eval_loss': 1.7113865613937378, 'eval_runtime': 2.0992, 'eval_samples_per_second': 20.008, 'eval_steps_per_second': 0.953, 'epoch': 64.73}


  return fn(*args, **kwargs)
 90%|█████████ | 180/200 [56:32<06:36, 19.82s/it]

{'loss': 1.8005, 'learning_rate': 5.418275829936537e-07, 'epoch': 65.45}


 90%|█████████ | 181/200 [56:49<06:03, 19.11s/it]
 90%|█████████ | 181/200 [56:55<06:03, 19.11s/it]

{'eval_loss': 1.711418867111206, 'eval_runtime': 1.8233, 'eval_samples_per_second': 23.035, 'eval_steps_per_second': 1.097, 'epoch': 65.82}


  return fn(*args, **kwargs)
 92%|█████████▏| 184/200 [57:44<04:52, 18.31s/it]
 92%|█████████▏| 184/200 [57:47<04:52, 18.31s/it]

{'eval_loss': 1.7115107774734497, 'eval_runtime': 1.7834, 'eval_samples_per_second': 23.551, 'eval_steps_per_second': 1.121, 'epoch': 66.91}


  return fn(*args, **kwargs)
 94%|█████████▎| 187/200 [58:37<03:51, 17.80s/it]
 94%|█████████▎| 187/200 [58:39<03:51, 17.80s/it]

{'eval_loss': 1.7115285396575928, 'eval_runtime': 1.8036, 'eval_samples_per_second': 23.287, 'eval_steps_per_second': 1.109, 'epoch': 68.0}


  return fn(*args, **kwargs)
 94%|█████████▍| 189/200 [59:16<03:22, 18.39s/it]
 94%|█████████▍| 189/200 [59:31<03:22, 18.39s/it]

{'eval_loss': 1.7113776206970215, 'eval_runtime': 1.826, 'eval_samples_per_second': 23.001, 'eval_steps_per_second': 1.095, 'epoch': 68.73}


  return fn(*args, **kwargs)
 95%|█████████▌| 190/200 [59:41<03:24, 20.41s/it]

{'loss': 1.7932, 'learning_rate': 1.3638696597277678e-07, 'epoch': 69.09}


 96%|█████████▌| 192/200 [1:00:16<02:31, 18.99s/it]
 96%|█████████▌| 192/200 [1:00:23<02:31, 18.99s/it]

{'eval_loss': 1.7114136219024658, 'eval_runtime': 1.8098, 'eval_samples_per_second': 23.206, 'eval_steps_per_second': 1.105, 'epoch': 69.82}


  return fn(*args, **kwargs)
 98%|█████████▊| 195/200 [1:01:11<01:31, 18.31s/it]
 98%|█████████▊| 195/200 [1:01:15<01:31, 18.31s/it]

{'eval_loss': 1.7114267349243164, 'eval_runtime': 2.0817, 'eval_samples_per_second': 20.176, 'eval_steps_per_second': 0.961, 'epoch': 70.91}


  return fn(*args, **kwargs)
 99%|█████████▉| 198/200 [1:02:05<00:35, 17.81s/it]
 99%|█████████▉| 198/200 [1:02:07<00:35, 17.81s/it]

{'eval_loss': 1.7114038467407227, 'eval_runtime': 1.8033, 'eval_samples_per_second': 23.29, 'eval_steps_per_second': 1.109, 'epoch': 72.0}


  return fn(*args, **kwargs)
100%|██████████| 200/200 [1:02:42<00:00, 18.21s/it]

{'loss': 1.7955, 'learning_rate': 0.0, 'epoch': 72.73}



100%|██████████| 200/200 [1:02:44<00:00, 18.21s/it]

{'eval_loss': 1.7115198373794556, 'eval_runtime': 2.0267, 'eval_samples_per_second': 20.724, 'eval_steps_per_second': 0.987, 'epoch': 72.73}


100%|██████████| 200/200 [1:02:45<00:00, 18.83s/it]


{'train_runtime': 3765.2146, 'train_samples_per_second': 9.03, 'train_steps_per_second': 0.053, 'train_loss': 1.9665108394622803, 'epoch': 72.73}
Model saved to ./immigration_assistant_final
Testing model on examples...
Loading base model for comparison...


  return torch.load(checkpoint_file, map_location="cpu")


Loading fine-tuned model...


  adapters_weights = torch.load(



Generating responses from base and fine-tuned models...


--- Example 1 ---
Question: I received a Notice of Intent to Deny (NOID) my case from the government. What can I do?
Reference Answer: Many times the government improperly concludes that a case is deniable. Our experienced attorneys have successfully resolved cases in which the government intends to deny the case. While results may vary depending upon fact patterns and a case cannot always be resolved, a consultation with an attorney may turn up another avenue of relief.
Base Model Output: If you receive a NOID, but your USCIS case is still pending, we will not consider it as part of our review process at that time and may continue processing other cases in your group for which there has been no adjudication yet. We cannot give you specific guidance on what type of action you should take if you have a Noid; however, generally speaking, any denial or delay is grounds for re-evaluation by us when your case moves forward with adju

In [52]:
response = query_model('What is the processing time for a green card application?')

In [1]:
import os
import json
import glob
import re
import pandas as pd
import torch
from datasets import Dataset, DatasetDict
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    DataCollatorForLanguageModeling
)
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training, get_peft_model
from trl import SFTTrainer
import evaluate

# ==================== CONFIGURATION ====================
# Set paths
BASE_DIRECTORY = "/home/hailemicaelyimer/Desktop/immigration-assistant/frequently-asked-questions"
DATASET_PATH = "./immigration_qa_dataset_clean"
OUTPUT_DIR = "./immigration_assistant_model_final"
FINAL_MODEL_PATH = "./immigration_assistant_final"
RESULTS_CSV = "./model_comparison_results.csv"

# Open-access model that doesn't require authentication
MODEL_ID = "facebook/opt-1.3b"  # 1.3B parameters, open access

# Training parameters
EPOCHS = 100
BATCH_SIZE = 6
LEARNING_RATE = 2e-5
LORA_RANK = 32
LORA_ALPHA = 64

# ==================== HELPER FUNCTIONS ====================

def clean_text(text):
    """Clean text by removing question/answer prefixes and extra whitespace."""
    # Remove "Q." or "Q#." prefixes from questions
    text = re.sub(r'^Q\.?\s*\d*\.?\s*', '', text)
    # Remove "A." or "A#." prefixes from answers
    text = re.sub(r'^A\.?\s*\d*\.?\s*', '', text)
    # Remove extra whitespace
    text = re.sub(r'\s+', ' ', text).strip()
    return text

def load_and_clean_json_files(directory_path):
    """Load and clean all JSON files in the directory."""
    json_files = glob.glob(os.path.join(directory_path, "**/*.json"), recursive=True)
    
    all_data = []
    print(f"Found {len(json_files)} JSON files")
    
    for file_path in json_files:
        try:
            with open(file_path, 'r', encoding='utf-8') as file:
                data = json.load(file)
                
                for item in data:
                    question = item.get("question", "").strip()
                    answer = item.get("answer", "").strip()
                    
                    # Skip items with empty answers or questions
                    if not question or not answer:
                        continue
                    
                    # Clean the texts
                    question = clean_text(question)
                    answer = clean_text(answer)
                    
                    # Skip very short answers (likely not useful)
                    if len(answer) < 20:
                        continue
                    
                    all_data.append({
                        "Question": question,
                        "Answer": answer
                    })
        except Exception as e:
            print(f"Error processing file {file_path}: {e}")
    
    print(f"Loaded {len(all_data)} clean question-answer pairs")
    return all_data

def post_process_response(text):
    """Clean model outputs by removing repetitions and known artifacts."""
    # Remove irrelevant prefix text
    if "Question:" in text and "Answer:" in text:
        text = text.split("Answer:", 1)[1].strip()
    
    # Split by lines and remove duplicates while preserving order
    lines = text.split('\n')
    seen_texts = set()
    unique_lines = []
    
    for line in lines:
        line = line.strip()
        if not line:
            continue
            
        # Skip duplicate content
        if line in seen_texts:
            continue
            
        # Skip lines that are question-like
        if line.lower().startswith(("question:", "q:", "what is", "how do", "can i")):
            continue
            
        seen_texts.add(line)
        unique_lines.append(line)
    
    # Join unique lines
    processed_text = '\n'.join(unique_lines)
    
    # If we filtered too much, return the original without duplicates
    if len(processed_text) < 20 and len(text) > 20:
        lines = text.split('\n')
        seen_texts = set()
        unique_lines = []
        for line in lines:
            if line.strip() and line.strip() not in seen_texts:
                seen_texts.add(line.strip())
                unique_lines.append(line)
        processed_text = '\n'.join(unique_lines)
    
    return processed_text

# ==================== MAIN SCRIPT ====================

# Step 1: Load and prepare data
print("Loading and cleaning data...")
all_qa_data = load_and_clean_json_files(BASE_DIRECTORY)

# Convert to DataFrame
df = pd.DataFrame(all_qa_data)
print(f"Dataset shape: {df.shape}")
print("Sample data:")
print(df.head(2))

# Split into train, validation, and test sets (80%, 10%, 10%)
train_size = int(0.8 * len(df))
val_size = int(0.1 * len(df))

train_df = df[:train_size]
val_df = df[train_size:train_size+val_size]
test_df = df[train_size+val_size:]

print(f"Train size: {len(train_df)}, Validation size: {len(val_df)}, Test size: {len(test_df)}")

# Convert to Hugging Face datasets
train_dataset = Dataset.from_pandas(train_df)
val_dataset = Dataset.from_pandas(val_df)
test_dataset = Dataset.from_pandas(test_df)

# Combine into a dataset dictionary
dataset_dict = DatasetDict({
    'train': train_dataset,
    'validation': val_dataset,
    'test': test_dataset
})

# Save the clean dataset to disk
os.makedirs(DATASET_PATH, exist_ok=True)
dataset_dict.save_to_disk(DATASET_PATH)
print(f"Dataset saved to {DATASET_PATH}")

# Step 2: Load Model and Tokenizer
# Define quantization config for 4-bit precision
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=False,
)

# Load tokenizer
print(f"Loading tokenizer for {MODEL_ID}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# Load model with quantization config
print("Loading model...")
device_map = {"": 0}  # Use GPU 0
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, 
    quantization_config=bnb_config, 
    use_cache=False,
    device_map=device_map
)

# Step 3: Define preprocessing function for clean instruction format
def preprocess_function(examples):
    # Use a clear instruction format without complex templates
    formatted_prompts = [
        f"### Instruction: You are an immigration assistant. Provide accurate information about this question: {q}\n\n### Response:" 
        for q in examples["Question"]
    ]
    
    return {
        "input_ids": tokenizer(
            formatted_prompts,
            truncation=True, 
            max_length=512, 
            padding="max_length"
        )["input_ids"],
        "labels": tokenizer(
            examples["Answer"], 
            truncation=True, 
            max_length=512, 
            padding="max_length"
        )["input_ids"],
        "inputs_text": [f"{prompt} {answer}" for prompt, answer in zip(formatted_prompts, examples["Answer"])],
    }

# Apply preprocessing to datasets
print("Preprocessing datasets...")
processed_train_dataset = dataset_dict['train'].map(preprocess_function, batched=True)
processed_val_dataset = dataset_dict['validation'].map(preprocess_function, batched=True)
processed_test_dataset = dataset_dict['test'].map(preprocess_function, batched=True)

print(f"Processed train dataset size: {len(processed_train_dataset)}")
print(f"Processed validation dataset size: {len(processed_val_dataset)}")
print(f"Processed test dataset size: {len(processed_test_dataset)}")

# Step 4: Configure LoRA for efficient fine-tuning
peft_config = LoraConfig(
    lora_alpha=LORA_ALPHA,
    lora_dropout=0.02,  # Reduced dropout for better learning
    r=LORA_RANK,
    bias="none",
    task_type="CAUSAL_LM",
    # Target projection layers in OPT model
    target_modules=["q_proj", "k_proj", "v_proj", "out_proj"]
)

# Prepare model for kbit training
print("Preparing model for training...")
model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

# Step 5: Define training arguments
os.makedirs(OUTPUT_DIR, exist_ok=True)

training_arguments = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    optim="paged_adamw_32bit",
    logging_steps=10,
    save_strategy="epoch",
    evaluation_strategy="epoch",
    learning_rate=LEARNING_RATE,
    weight_decay=0.01,
    max_grad_norm=0.3,
    warmup_ratio=0.05,
    group_by_length=True,
    lr_scheduler_type="cosine",
    fp16=False,
    bf16=False,
    report_to="none",
    # Add the following to prevent repetition during training
    remove_unused_columns=False,
    label_names=["labels"],
)

# Data collator for language model training
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False
)

# Step 6: Create and train the model
trainer = SFTTrainer(
    model=model,
    train_dataset=processed_train_dataset,
    eval_dataset=processed_val_dataset,
    dataset_text_field="inputs_text",
    max_seq_length=512,
    tokenizer=tokenizer,
    args=training_arguments,
    data_collator=data_collator,
    packing=False,
)

# Start training
print("Starting training...")
trainer.train()

# Step 7: Save the trained model
os.makedirs(FINAL_MODEL_PATH, exist_ok=True)
trainer.model.save_pretrained(FINAL_MODEL_PATH)
tokenizer.save_pretrained(FINAL_MODEL_PATH)
print(f"Model saved to {FINAL_MODEL_PATH}")

# Step 8: Test the model on a few examples
print("Testing model on examples...")

# Load rouge for evaluation
rouge = evaluate.load('rouge')

# Reload base model for comparison
print("Loading base model for comparison...")
base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    device_map=device_map
)

# Load fine-tuned model (PEFT)
print("Loading fine-tuned model...")
peft_model = PeftModel.from_pretrained(
    base_model,
    FINAL_MODEL_PATH,
    device_map=device_map
)

# Test on examples from test set
test_questions = test_df['Question'][:10].tolist()  # Test on 10 examples
test_answers = test_df['Answer'][:10].tolist()

base_model_outputs = []
peft_model_outputs = []

print("\nGenerating responses from base and fine-tuned models...")
for question in test_questions:
    # Format prompt for the model
    prompt = f"### Instruction: You are an immigration assistant. Provide accurate information about this question: {question}\n\n### Response:"
    
    # Base model generation
    input_ids = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).input_ids.to("cuda")
    base_outputs = base_model.generate(
        input_ids=input_ids, 
        max_length=len(input_ids[0]) + 250, 
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        repetition_penalty=1.3,  # Increased repetition penalty
        no_repeat_ngram_size=3   # Prevent repeating 3-grams
    )
    base_text = tokenizer.decode(base_outputs[0], skip_special_tokens=True)
    
    # Clean the base model output
    base_text = base_text.replace(prompt, "").strip()
    base_text = post_process_response(base_text)
    base_model_outputs.append(base_text)
    
    # Fine-tuned model generation
    ft_outputs = peft_model.generate(
        input_ids=input_ids, 
        max_length=len(input_ids[0]) + 250, 
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        repetition_penalty=1.3,
        no_repeat_ngram_size=3
    )
    ft_text = tokenizer.decode(ft_outputs[0], skip_special_tokens=True)
    
    # Clean the fine-tuned model output
    ft_text = ft_text.replace(prompt, "").strip()
    ft_text = post_process_response(ft_text)
    peft_model_outputs.append(ft_text)

# Print results for a few examples
for i, (question, answer, base_output, peft_output) in enumerate(zip(test_questions[:3], test_answers[:3], base_model_outputs[:3], peft_model_outputs[:3])):
    print(f"\n\n--- Example {i+1} ---")
    print(f"Question: {question}")
    print(f"Reference Answer: {answer}")
    print(f"Base Model Output: {base_output}")
    print(f"Fine-tuned Model Output: {peft_output}")

# Calculate ROUGE scores
base_rouge_results = rouge.compute(
    predictions=base_model_outputs,
    references=test_answers[:len(base_model_outputs)],
    use_stemmer=True
)

peft_rouge_results = rouge.compute(
    predictions=peft_model_outputs,
    references=test_answers[:len(peft_model_outputs)],
    use_stemmer=True
)

print("\n--- ROUGE Scores ---")
print("Base Model:")
print(base_rouge_results)
print("\nFine-tuned Model:")
print(peft_rouge_results)

# Save the generated responses for manual inspection
results_df = pd.DataFrame({
    "Question": test_questions,
    "Reference_Answer": test_answers[:len(test_questions)],
    "Base_Model_Output": base_model_outputs,
    "Fine_Tuned_Output": peft_model_outputs
})
results_df.to_csv(RESULTS_CSV, index=False)
print(f"\nSaved comparison results to {RESULTS_CSV}")

# Create a simple inference function to test the model interactively
def query_model(question, model=peft_model):
    prompt = f"### Instruction: You are an immigration assistant. Provide accurate information about this question: {question}\n\n### Response:"
    input_ids = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).input_ids.to("cuda")
    
    outputs = model.generate(
        input_ids=input_ids, 
        max_length=len(input_ids[0]) + 300, 
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        repetition_penalty=1.3,
        no_repeat_ngram_size=3
    )
    
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Clean the response
    response = response.replace(prompt, "").strip()
    response = post_process_response(response)
    return response

print("\nTraining and evaluation complete!")
print("\nYou can now use the query_model() function to test your model interactively.")
print("Example: response = query_model('What is the processing time for a green card application?')")

  from .autonotebook import tqdm as notebook_tqdm


Loading and cleaning data...
Found 28 JSON files
Loaded 425 clean question-answer pairs
Dataset shape: (425, 2)
Sample data:
                                            Question  \
0  fter one year, how do I demonstrate that the n...   
1  Where can I find information about vaccination...   

                                              Answer  
0  International Entrepreneur RuleUnder the Inter...  
1  CDC publishes information about vaccinations i...  
Train size: 340, Validation size: 42, Test size: 43


Saving the dataset (1/1 shards): 100%|██████████| 340/340 [00:00<00:00, 37946.39 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 42/42 [00:00<00:00, 6009.65 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 43/43 [00:00<00:00, 5365.95 examples/s]

Dataset saved to ./immigration_qa_dataset_clean
Loading tokenizer for facebook/opt-1.3b...





Loading model...


  return torch.load(checkpoint_file, map_location="cpu")


Preprocessing datasets...


Map: 100%|██████████| 340/340 [00:00<00:00, 1562.34 examples/s]
Map: 100%|██████████| 42/42 [00:00<00:00, 1770.64 examples/s]
Map: 100%|██████████| 43/43 [00:00<00:00, 551.27 examples/s]


Processed train dataset size: 340
Processed validation dataset size: 42
Processed test dataset size: 43
Preparing model for training...
trainable params: 12,582,912 || all params: 724,361,216 || trainable%: 1.7371045994820353


Map: 100%|██████████| 340/340 [00:00<00:00, 2155.28 examples/s]
Map: 100%|██████████| 42/42 [00:00<00:00, 5897.58 examples/s]


Starting training...


  0%|          | 0/300 [00:00<?, ?it/s]You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
  return fn(*args, **kwargs)
                                                 
  1%|          | 3/300 [00:53<1:07:38, 13.66s/it]

{'eval_loss': 2.843508005142212, 'eval_runtime': 1.4971, 'eval_samples_per_second': 28.055, 'eval_steps_per_second': 1.336, 'epoch': 0.8}


  return fn(*args, **kwargs)
                                                 
  2%|▏         | 7/300 [01:46<1:09:45, 14.28s/it]

{'eval_loss': 2.7894909381866455, 'eval_runtime': 1.5816, 'eval_samples_per_second': 26.556, 'eval_steps_per_second': 1.265, 'epoch': 1.87}


  return fn(*args, **kwargs)
  3%|▎         | 10/300 [02:24<1:07:48, 14.03s/it]

{'loss': 2.7819, 'learning_rate': 1.3333333333333333e-05, 'epoch': 2.67}


                                                  
  4%|▎         | 11/300 [02:39<1:06:20, 13.77s/it]

{'eval_loss': 2.6966001987457275, 'eval_runtime': 1.5994, 'eval_samples_per_second': 26.26, 'eval_steps_per_second': 1.25, 'epoch': 2.93}


  return fn(*args, **kwargs)
                                                  
  5%|▌         | 15/300 [03:34<1:04:40, 13.61s/it]

{'eval_loss': 2.579209566116333, 'eval_runtime': 1.5599, 'eval_samples_per_second': 26.925, 'eval_steps_per_second': 1.282, 'epoch': 4.0}


  return fn(*args, **kwargs)
                                                  
  6%|▌         | 18/300 [04:28<1:05:54, 14.02s/it]

{'eval_loss': 2.4874427318573, 'eval_runtime': 1.7519, 'eval_samples_per_second': 23.973, 'eval_steps_per_second': 1.142, 'epoch': 4.8}


  return fn(*args, **kwargs)
  7%|▋         | 20/300 [04:49<1:10:26, 15.10s/it]

{'loss': 2.6335, 'learning_rate': 1.9984815164333163e-05, 'epoch': 5.33}


                                                  
  7%|▋         | 22/300 [05:22<1:07:17, 14.52s/it]

{'eval_loss': 2.3583312034606934, 'eval_runtime': 1.5174, 'eval_samples_per_second': 27.68, 'eval_steps_per_second': 1.318, 'epoch': 5.87}


  return fn(*args, **kwargs)
                                                  
  9%|▊         | 26/300 [06:15<1:03:08, 13.83s/it]

{'eval_loss': 2.216672420501709, 'eval_runtime': 1.6099, 'eval_samples_per_second': 26.089, 'eval_steps_per_second': 1.242, 'epoch': 6.93}


  return fn(*args, **kwargs)
 10%|█         | 30/300 [07:07<1:00:36, 13.47s/it]

{'loss': 2.4104, 'learning_rate': 1.9863613034027224e-05, 'epoch': 8.0}


                                                  
 10%|█         | 30/300 [07:09<1:00:36, 13.47s/it]

{'eval_loss': 2.0713632106781006, 'eval_runtime': 1.6171, 'eval_samples_per_second': 25.973, 'eval_steps_per_second': 1.237, 'epoch': 8.0}


  return fn(*args, **kwargs)
                                                  
 11%|█         | 33/300 [08:03<1:02:33, 14.06s/it]

{'eval_loss': 1.9623329639434814, 'eval_runtime': 1.5517, 'eval_samples_per_second': 27.067, 'eval_steps_per_second': 1.289, 'epoch': 8.8}


  return fn(*args, **kwargs)
                                                  
 12%|█▏        | 37/300 [08:56<1:03:24, 14.47s/it]

{'eval_loss': 1.8345930576324463, 'eval_runtime': 1.6719, 'eval_samples_per_second': 25.12, 'eval_steps_per_second': 1.196, 'epoch': 9.87}


  return fn(*args, **kwargs)
 13%|█▎        | 40/300 [09:34<1:01:27, 14.18s/it]

{'loss': 2.1742, 'learning_rate': 1.9622680003092503e-05, 'epoch': 10.67}


                                                  
 14%|█▎        | 41/300 [09:50<59:43, 13.83s/it]

{'eval_loss': 1.7629668712615967, 'eval_runtime': 1.6526, 'eval_samples_per_second': 25.415, 'eval_steps_per_second': 1.21, 'epoch': 10.93}


  return fn(*args, **kwargs)
                                                  
 15%|█▌        | 45/300 [10:44<57:08, 13.44s/it]

{'eval_loss': 1.7088862657546997, 'eval_runtime': 1.6832, 'eval_samples_per_second': 24.953, 'eval_steps_per_second': 1.188, 'epoch': 12.0}


  return fn(*args, **kwargs)
                                                
 16%|█▌        | 48/300 [11:37<58:44, 13.99s/it]

{'eval_loss': 1.6935629844665527, 'eval_runtime': 1.8589, 'eval_samples_per_second': 22.594, 'eval_steps_per_second': 1.076, 'epoch': 12.8}


  return fn(*args, **kwargs)
 17%|█▋        | 50/300 [11:59<1:02:50, 15.08s/it]

{'loss': 2.006, 'learning_rate': 1.9264940672148018e-05, 'epoch': 13.33}


                                                  
 17%|█▋        | 52/300 [12:30<58:14, 14.09s/it]

{'eval_loss': 1.6903311014175415, 'eval_runtime': 1.8338, 'eval_samples_per_second': 22.903, 'eval_steps_per_second': 1.091, 'epoch': 13.87}


  return fn(*args, **kwargs)
                                                
 19%|█▊        | 56/300 [13:24<55:46, 13.72s/it]

{'eval_loss': 1.6974174976348877, 'eval_runtime': 1.6583, 'eval_samples_per_second': 25.327, 'eval_steps_per_second': 1.206, 'epoch': 14.93}


  return fn(*args, **kwargs)
 20%|██        | 60/300 [14:16<53:48, 13.45s/it]

{'loss': 1.9274, 'learning_rate': 1.879473751206489e-05, 'epoch': 16.0}


                                                
 20%|██        | 60/300 [14:18<53:48, 13.45s/it]

{'eval_loss': 1.7003488540649414, 'eval_runtime': 1.8852, 'eval_samples_per_second': 22.279, 'eval_steps_per_second': 1.061, 'epoch': 16.0}


  return fn(*args, **kwargs)
                                                
 21%|██        | 63/300 [15:11<54:53, 13.90s/it]

{'eval_loss': 1.693541169166565, 'eval_runtime': 1.7625, 'eval_samples_per_second': 23.83, 'eval_steps_per_second': 1.135, 'epoch': 16.8}


  return fn(*args, **kwargs)
                                                  
 22%|██▏       | 67/300 [16:06<56:37, 14.58s/it]

{'eval_loss': 1.6909723281860352, 'eval_runtime': 1.6607, 'eval_samples_per_second': 25.29, 'eval_steps_per_second': 1.204, 'epoch': 17.87}


  return fn(*args, **kwargs)
 23%|██▎       | 70/300 [16:44<54:55, 14.33s/it]

{'loss': 1.9174, 'learning_rate': 1.821777815225245e-05, 'epoch': 18.67}


                                                
 24%|██▎       | 71/300 [17:00<53:11, 13.94s/it]

{'eval_loss': 1.6928735971450806, 'eval_runtime': 1.6598, 'eval_samples_per_second': 25.304, 'eval_steps_per_second': 1.205, 'epoch': 18.93}


  return fn(*args, **kwargs)
                                                
 25%|██▌       | 75/300 [17:53<50:23, 13.44s/it]

{'eval_loss': 1.695127010345459, 'eval_runtime': 1.6812, 'eval_samples_per_second': 24.983, 'eval_steps_per_second': 1.19, 'epoch': 20.0}


  return fn(*args, **kwargs)
                                                
 26%|██▌       | 78/300 [18:47<51:33, 13.94s/it]

{'eval_loss': 1.6963386535644531, 'eval_runtime': 1.5395, 'eval_samples_per_second': 27.282, 'eval_steps_per_second': 1.299, 'epoch': 20.8}


  return fn(*args, **kwargs)
 27%|██▋       | 80/300 [19:08<54:50, 14.96s/it]

{'loss': 1.8595, 'learning_rate': 1.7541066097768965e-05, 'epoch': 21.33}


                                                
 27%|██▋       | 82/300 [19:40<52:18, 14.40s/it]

{'eval_loss': 1.696678876876831, 'eval_runtime': 1.3942, 'eval_samples_per_second': 30.125, 'eval_steps_per_second': 1.435, 'epoch': 21.87}


  return fn(*args, **kwargs)
                                                
 29%|██▊       | 86/300 [20:34<49:11, 13.79s/it]

{'eval_loss': 1.6979856491088867, 'eval_runtime': 1.6065, 'eval_samples_per_second': 26.143, 'eval_steps_per_second': 1.245, 'epoch': 22.93}


  return fn(*args, **kwargs)
 30%|███       | 90/300 [21:25<46:04, 13.16s/it]

{'loss': 1.848, 'learning_rate': 1.6772815716257414e-05, 'epoch': 24.0}


                                                
 30%|███       | 90/300 [21:26<46:04, 13.16s/it]

{'eval_loss': 1.7007620334625244, 'eval_runtime': 1.3999, 'eval_samples_per_second': 30.002, 'eval_steps_per_second': 1.429, 'epoch': 24.0}


  return fn(*args, **kwargs)
                                                
 31%|███       | 93/300 [22:20<47:18, 13.71s/it]

{'eval_loss': 1.6997026205062866, 'eval_runtime': 1.5671, 'eval_samples_per_second': 26.8, 'eval_steps_per_second': 1.276, 'epoch': 24.8}


  return fn(*args, **kwargs)
                                                
 32%|███▏      | 97/300 [23:13<48:17, 14.27s/it]

{'eval_loss': 1.695892095565796, 'eval_runtime': 1.67, 'eval_samples_per_second': 25.15, 'eval_steps_per_second': 1.198, 'epoch': 25.87}


  return fn(*args, **kwargs)
 33%|███▎      | 100/300 [23:50<46:56, 14.08s/it]

{'loss': 1.8459, 'learning_rate': 1.5922352526649803e-05, 'epoch': 26.67}


                                                 
 34%|███▎      | 101/300 [24:06<46:01, 13.88s/it]

{'eval_loss': 1.6963555812835693, 'eval_runtime': 1.5542, 'eval_samples_per_second': 27.023, 'eval_steps_per_second': 1.287, 'epoch': 26.93}


  return fn(*args, **kwargs)
                                                 
 35%|███▌      | 105/300 [25:00<43:31, 13.39s/it]

{'eval_loss': 1.7023065090179443, 'eval_runtime': 1.6271, 'eval_samples_per_second': 25.812, 'eval_steps_per_second': 1.229, 'epoch': 28.0}


  return fn(*args, **kwargs)
                                                 
 36%|███▌      | 108/300 [25:52<44:25, 13.88s/it]

{'eval_loss': 1.7052775621414185, 'eval_runtime': 1.5309, 'eval_samples_per_second': 27.435, 'eval_steps_per_second': 1.306, 'epoch': 28.8}


  return fn(*args, **kwargs)
 37%|███▋      | 110/300 [26:13<46:30, 14.69s/it]

{'loss': 1.8082, 'learning_rate': 1.5000000000000002e-05, 'epoch': 29.33}


                                                 
 37%|███▋      | 112/300 [26:46<44:48, 14.30s/it]

{'eval_loss': 1.7082090377807617, 'eval_runtime': 1.6577, 'eval_samples_per_second': 25.337, 'eval_steps_per_second': 1.207, 'epoch': 29.87}


  return fn(*args, **kwargs)
                                                 
 39%|███▊      | 116/300 [27:39<42:14, 13.77s/it]

{'eval_loss': 1.707259178161621, 'eval_runtime': 1.8151, 'eval_samples_per_second': 23.139, 'eval_steps_per_second': 1.102, 'epoch': 30.93}


  return fn(*args, **kwargs)
 40%|████      | 120/300 [28:31<40:20, 13.44s/it]

{'loss': 1.8021, 'learning_rate': 1.4016954246529697e-05, 'epoch': 32.0}


                                                 
 40%|████      | 120/300 [28:33<40:20, 13.44s/it]

{'eval_loss': 1.7056208848953247, 'eval_runtime': 1.4186, 'eval_samples_per_second': 29.606, 'eval_steps_per_second': 1.41, 'epoch': 32.0}


  return fn(*args, **kwargs)
                                                 
 41%|████      | 123/300 [29:26<40:48, 13.83s/it]

{'eval_loss': 1.7048007249832153, 'eval_runtime': 1.5312, 'eval_samples_per_second': 27.429, 'eval_steps_per_second': 1.306, 'epoch': 32.8}


  return fn(*args, **kwargs)
                                                 
 42%|████▏     | 127/300 [30:20<41:42, 14.47s/it]

{'eval_loss': 1.7089468240737915, 'eval_runtime': 1.8336, 'eval_samples_per_second': 22.906, 'eval_steps_per_second': 1.091, 'epoch': 33.87}


  return fn(*args, **kwargs)
 43%|████▎     | 130/300 [30:58<40:35, 14.33s/it]

{'loss': 1.8131, 'learning_rate': 1.2985148110016947e-05, 'epoch': 34.67}


                                                 
 44%|████▎     | 131/300 [31:14<39:03, 13.87s/it]

{'eval_loss': 1.7102563381195068, 'eval_runtime': 1.6239, 'eval_samples_per_second': 25.864, 'eval_steps_per_second': 1.232, 'epoch': 34.93}


  return fn(*args, **kwargs)
                                                 
 45%|████▌     | 135/300 [32:07<37:03, 13.48s/it]

{'eval_loss': 1.7104980945587158, 'eval_runtime': 1.6372, 'eval_samples_per_second': 25.653, 'eval_steps_per_second': 1.222, 'epoch': 36.0}


  return fn(*args, **kwargs)
                                                 
 46%|████▌     | 138/300 [33:01<37:47, 13.99s/it]

{'eval_loss': 1.7109707593917847, 'eval_runtime': 1.9888, 'eval_samples_per_second': 21.118, 'eval_steps_per_second': 1.006, 'epoch': 36.8}


  return fn(*args, **kwargs)
 47%|████▋     | 140/300 [33:22<39:56, 14.98s/it]

{'loss': 1.7739, 'learning_rate': 1.1917106319237386e-05, 'epoch': 37.33}


                                                 
 47%|████▋     | 142/300 [33:55<37:57, 14.41s/it]

{'eval_loss': 1.711273431777954, 'eval_runtime': 1.4598, 'eval_samples_per_second': 28.772, 'eval_steps_per_second': 1.37, 'epoch': 37.87}


  return fn(*args, **kwargs)
                                                 
 49%|████▊     | 146/300 [34:48<35:20, 13.77s/it]

{'eval_loss': 1.7140793800354004, 'eval_runtime': 1.5134, 'eval_samples_per_second': 27.753, 'eval_steps_per_second': 1.322, 'epoch': 38.93}


  return fn(*args, **kwargs)
 50%|█████     | 150/300 [35:40<33:29, 13.39s/it]

{'loss': 1.765, 'learning_rate': 1.0825793454723325e-05, 'epoch': 40.0}


                                                 
 50%|█████     | 150/300 [35:42<33:29, 13.39s/it]

{'eval_loss': 1.7137072086334229, 'eval_runtime': 1.6215, 'eval_samples_per_second': 25.903, 'eval_steps_per_second': 1.233, 'epoch': 40.0}


  return fn(*args, **kwargs)
                                                 
 51%|█████     | 153/300 [36:36<34:00, 13.88s/it]

{'eval_loss': 1.7134851217269897, 'eval_runtime': 1.679, 'eval_samples_per_second': 25.014, 'eval_steps_per_second': 1.191, 'epoch': 40.8}


  return fn(*args, **kwargs)
                                                 
 52%|█████▏    | 157/300 [37:30<34:15, 14.37s/it]

{'eval_loss': 1.7138193845748901, 'eval_runtime': 1.7637, 'eval_samples_per_second': 23.814, 'eval_steps_per_second': 1.134, 'epoch': 41.87}


  return fn(*args, **kwargs)
 53%|█████▎    | 160/300 [38:07<33:06, 14.19s/it]

{'loss': 1.7844, 'learning_rate': 9.724456576318383e-06, 'epoch': 42.67}


                                                 
 54%|█████▎    | 161/300 [38:23<32:05, 13.85s/it]

{'eval_loss': 1.71701979637146, 'eval_runtime': 1.8674, 'eval_samples_per_second': 22.491, 'eval_steps_per_second': 1.071, 'epoch': 42.93}


  return fn(*args, **kwargs)
                                                 
 55%|█████▌    | 165/300 [39:18<30:22, 13.50s/it]

{'eval_loss': 1.7173590660095215, 'eval_runtime': 1.8804, 'eval_samples_per_second': 22.335, 'eval_steps_per_second': 1.064, 'epoch': 44.0}


  return fn(*args, **kwargs)
                                                 
 56%|█████▌    | 168/300 [40:11<30:38, 13.93s/it]

{'eval_loss': 1.7167714834213257, 'eval_runtime': 1.5351, 'eval_samples_per_second': 27.36, 'eval_steps_per_second': 1.303, 'epoch': 44.8}


  return fn(*args, **kwargs)
 57%|█████▋    | 170/300 [40:32<32:22, 14.94s/it]

{'loss': 1.7502, 'learning_rate': 8.626464421815919e-06, 'epoch': 45.33}


                                                 
 57%|█████▋    | 172/300 [41:05<30:41, 14.38s/it]

{'eval_loss': 1.7175542116165161, 'eval_runtime': 1.9289, 'eval_samples_per_second': 21.774, 'eval_steps_per_second': 1.037, 'epoch': 45.87}


  return fn(*args, **kwargs)
                                                 
 59%|█████▊    | 176/300 [41:58<28:23, 13.73s/it]

{'eval_loss': 1.7188231945037842, 'eval_runtime': 1.6934, 'eval_samples_per_second': 24.802, 'eval_steps_per_second': 1.181, 'epoch': 46.93}


  return fn(*args, **kwargs)
 60%|██████    | 180/300 [42:50<26:52, 13.44s/it]

{'loss': 1.7391, 'learning_rate': 7.545145128592009e-06, 'epoch': 48.0}


                                                 
 60%|██████    | 180/300 [42:52<26:52, 13.44s/it]

{'eval_loss': 1.7187964916229248, 'eval_runtime': 1.603, 'eval_samples_per_second': 26.201, 'eval_steps_per_second': 1.248, 'epoch': 48.0}


  return fn(*args, **kwargs)
                                                 
 61%|██████    | 183/300 [43:45<27:05, 13.90s/it]

{'eval_loss': 1.7177810668945312, 'eval_runtime': 1.7277, 'eval_samples_per_second': 24.31, 'eval_steps_per_second': 1.158, 'epoch': 48.8}


  return fn(*args, **kwargs)
                                                 
 62%|██████▏   | 187/300 [44:39<26:57, 14.32s/it]

{'eval_loss': 1.7189439535140991, 'eval_runtime': 1.6623, 'eval_samples_per_second': 25.266, 'eval_steps_per_second': 1.203, 'epoch': 49.87}


  return fn(*args, **kwargs)
 63%|██████▎   | 190/300 [45:16<25:56, 14.15s/it]

{'loss': 1.7425, 'learning_rate': 6.4936244480724575e-06, 'epoch': 50.67}


                                                 
 64%|██████▎   | 191/300 [45:33<25:14, 13.90s/it]

{'eval_loss': 1.7195024490356445, 'eval_runtime': 1.6168, 'eval_samples_per_second': 25.978, 'eval_steps_per_second': 1.237, 'epoch': 50.93}


  return fn(*args, **kwargs)
                                                 
 65%|██████▌   | 195/300 [46:26<23:27, 13.41s/it]

{'eval_loss': 1.7202470302581787, 'eval_runtime': 1.6781, 'eval_samples_per_second': 25.028, 'eval_steps_per_second': 1.192, 'epoch': 52.0}


  return fn(*args, **kwargs)
                                                 
 66%|██████▌   | 198/300 [47:20<23:46, 13.99s/it]

{'eval_loss': 1.720544695854187, 'eval_runtime': 1.9928, 'eval_samples_per_second': 21.076, 'eval_steps_per_second': 1.004, 'epoch': 52.8}


  return fn(*args, **kwargs)
 67%|██████▋   | 200/300 [47:42<25:12, 15.12s/it]

{'loss': 1.7466, 'learning_rate': 5.484666416891109e-06, 'epoch': 53.33}


                                                 
 67%|██████▋   | 202/300 [48:14<23:31, 14.41s/it]

{'eval_loss': 1.7194780111312866, 'eval_runtime': 1.6283, 'eval_samples_per_second': 25.794, 'eval_steps_per_second': 1.228, 'epoch': 53.87}


  return fn(*args, **kwargs)
                                                 
 69%|██████▊   | 206/300 [49:08<21:41, 13.84s/it]

{'eval_loss': 1.719663381576538, 'eval_runtime': 1.6571, 'eval_samples_per_second': 25.345, 'eval_steps_per_second': 1.207, 'epoch': 54.93}


  return fn(*args, **kwargs)
 70%|███████   | 210/300 [49:59<19:52, 13.26s/it]

{'loss': 1.7256, 'learning_rate': 4.530518418775734e-06, 'epoch': 56.0}


                                                 
 70%|███████   | 210/300 [50:01<19:52, 13.26s/it]

{'eval_loss': 1.7206614017486572, 'eval_runtime': 1.9562, 'eval_samples_per_second': 21.47, 'eval_steps_per_second': 1.022, 'epoch': 56.0}


  return fn(*args, **kwargs)
                                                 
 71%|███████   | 213/300 [50:55<20:18, 14.01s/it]

{'eval_loss': 1.7211276292800903, 'eval_runtime': 1.5839, 'eval_samples_per_second': 26.516, 'eval_steps_per_second': 1.263, 'epoch': 56.8}


  return fn(*args, **kwargs)
                                                 
 72%|███████▏  | 217/300 [51:49<19:58, 14.44s/it]

{'eval_loss': 1.7211543321609497, 'eval_runtime': 1.5551, 'eval_samples_per_second': 27.008, 'eval_steps_per_second': 1.286, 'epoch': 57.87}


  return fn(*args, **kwargs)
 73%|███████▎  | 220/300 [52:27<18:58, 14.23s/it]

{'loss': 1.7425, 'learning_rate': 3.6427625179003223e-06, 'epoch': 58.67}


                                                 
 74%|███████▎  | 221/300 [52:43<18:08, 13.78s/it]

{'eval_loss': 1.7212646007537842, 'eval_runtime': 1.8979, 'eval_samples_per_second': 22.129, 'eval_steps_per_second': 1.054, 'epoch': 58.93}


  return fn(*args, **kwargs)
                                                 
 75%|███████▌  | 225/300 [53:37<16:54, 13.52s/it]

{'eval_loss': 1.7219167947769165, 'eval_runtime': 1.705, 'eval_samples_per_second': 24.633, 'eval_steps_per_second': 1.173, 'epoch': 60.0}


  return fn(*args, **kwargs)
                                                 
 76%|███████▌  | 228/300 [54:31<16:50, 14.03s/it]

{'eval_loss': 1.722355604171753, 'eval_runtime': 1.6446, 'eval_samples_per_second': 25.538, 'eval_steps_per_second': 1.216, 'epoch': 60.8}


  return fn(*args, **kwargs)
 77%|███████▋  | 230/300 [54:52<17:38, 15.11s/it]

{'loss': 1.7093, 'learning_rate': 2.8321748683154893e-06, 'epoch': 61.33}


                                                 
 77%|███████▋  | 232/300 [55:25<16:19, 14.40s/it]

{'eval_loss': 1.7221194505691528, 'eval_runtime': 1.8364, 'eval_samples_per_second': 22.871, 'eval_steps_per_second': 1.089, 'epoch': 61.87}


  return fn(*args, **kwargs)
                                                 
 79%|███████▊  | 236/300 [56:18<14:43, 13.80s/it]

{'eval_loss': 1.7225074768066406, 'eval_runtime': 1.6275, 'eval_samples_per_second': 25.807, 'eval_steps_per_second': 1.229, 'epoch': 62.93}


  return fn(*args, **kwargs)
 80%|████████  | 240/300 [57:11<13:33, 13.56s/it]

{'loss': 1.7115, 'learning_rate': 2.1085949060360654e-06, 'epoch': 64.0}


                                                 
 80%|████████  | 240/300 [57:12<13:33, 13.56s/it]

{'eval_loss': 1.7230054140090942, 'eval_runtime': 1.6367, 'eval_samples_per_second': 25.661, 'eval_steps_per_second': 1.222, 'epoch': 64.0}


  return fn(*args, **kwargs)
                                                 
 81%|████████  | 243/300 [58:06<13:11, 13.89s/it]

{'eval_loss': 1.723283290863037, 'eval_runtime': 1.8305, 'eval_samples_per_second': 22.944, 'eval_steps_per_second': 1.093, 'epoch': 64.8}


  return fn(*args, **kwargs)
                                                 
 82%|████████▏ | 247/300 [58:59<12:44, 14.42s/it]

{'eval_loss': 1.7235742807388306, 'eval_runtime': 1.6351, 'eval_samples_per_second': 25.686, 'eval_steps_per_second': 1.223, 'epoch': 65.87}


  return fn(*args, **kwargs)
 83%|████████▎ | 250/300 [59:37<11:51, 14.23s/it]

{'loss': 1.7336, 'learning_rate': 1.4808059116167306e-06, 'epoch': 66.67}


                                                 
 84%|████████▎ | 251/300 [59:53<11:22, 13.93s/it]

{'eval_loss': 1.7235395908355713, 'eval_runtime': 1.7565, 'eval_samples_per_second': 23.912, 'eval_steps_per_second': 1.139, 'epoch': 66.93}


  return fn(*args, **kwargs)
                                                   
 85%|████████▌ | 255/300 [1:00:48<10:08, 13.53s/it]

{'eval_loss': 1.723737120628357, 'eval_runtime': 1.8593, 'eval_samples_per_second': 22.589, 'eval_steps_per_second': 1.076, 'epoch': 68.0}


  return fn(*args, **kwargs)
                                                   
 86%|████████▌ | 258/300 [1:01:42<09:52, 14.10s/it]

{'eval_loss': 1.7239004373550415, 'eval_runtime': 1.5463, 'eval_samples_per_second': 27.162, 'eval_steps_per_second': 1.293, 'epoch': 68.8}


  return fn(*args, **kwargs)
 87%|████████▋ | 260/300 [1:02:03<10:10, 15.27s/it]

{'loss': 1.7134, 'learning_rate': 9.564283930242258e-07, 'epoch': 69.33}


                                                   
 87%|████████▋ | 262/300 [1:02:35<09:11, 14.52s/it]

{'eval_loss': 1.7238432168960571, 'eval_runtime': 1.5629, 'eval_samples_per_second': 26.873, 'eval_steps_per_second': 1.28, 'epoch': 69.87}


  return fn(*args, **kwargs)
                                                   
 89%|████████▊ | 266/300 [1:03:29<07:51, 13.87s/it]

{'eval_loss': 1.7238762378692627, 'eval_runtime': 1.7855, 'eval_samples_per_second': 23.522, 'eval_steps_per_second': 1.12, 'epoch': 70.93}


  return fn(*args, **kwargs)
 90%|█████████ | 270/300 [1:04:21<06:44, 13.47s/it]

{'loss': 1.7127, 'learning_rate': 5.418275829936537e-07, 'epoch': 72.0}


                                                   
 90%|█████████ | 270/300 [1:04:23<06:44, 13.47s/it]

{'eval_loss': 1.7238558530807495, 'eval_runtime': 1.5346, 'eval_samples_per_second': 27.369, 'eval_steps_per_second': 1.303, 'epoch': 72.0}


  return fn(*args, **kwargs)
                                                   
 91%|█████████ | 273/300 [1:05:16<06:14, 13.88s/it]

{'eval_loss': 1.7239043712615967, 'eval_runtime': 1.58, 'eval_samples_per_second': 26.582, 'eval_steps_per_second': 1.266, 'epoch': 72.8}


  return fn(*args, **kwargs)
                                                   
 92%|█████████▏| 277/300 [1:06:10<05:30, 14.36s/it]

{'eval_loss': 1.7237932682037354, 'eval_runtime': 1.6798, 'eval_samples_per_second': 25.003, 'eval_steps_per_second': 1.191, 'epoch': 73.87}


  return fn(*args, **kwargs)
 93%|█████████▎| 280/300 [1:06:48<04:45, 14.25s/it]

{'loss': 1.7384, 'learning_rate': 2.420361737256438e-07, 'epoch': 74.67}


                                                   
 94%|█████████▎| 281/300 [1:07:04<04:24, 13.92s/it]

{'eval_loss': 1.723901629447937, 'eval_runtime': 1.8386, 'eval_samples_per_second': 22.843, 'eval_steps_per_second': 1.088, 'epoch': 74.93}


  return fn(*args, **kwargs)
                                                   
 95%|█████████▌| 285/300 [1:07:58<03:22, 13.49s/it]

{'eval_loss': 1.723875880241394, 'eval_runtime': 1.5238, 'eval_samples_per_second': 27.562, 'eval_steps_per_second': 1.312, 'epoch': 76.0}


  return fn(*args, **kwargs)
                                                   
 96%|█████████▌| 288/300 [1:08:52<02:48, 14.05s/it]

{'eval_loss': 1.723939061164856, 'eval_runtime': 1.6375, 'eval_samples_per_second': 25.649, 'eval_steps_per_second': 1.221, 'epoch': 76.8}


  return fn(*args, **kwargs)
 97%|█████████▋| 290/300 [1:09:13<02:30, 15.07s/it]

{'loss': 1.7046, 'learning_rate': 6.069322682050516e-08, 'epoch': 77.33}


                                                   
 97%|█████████▋| 292/300 [1:09:46<01:55, 14.47s/it]

{'eval_loss': 1.7239011526107788, 'eval_runtime': 1.8006, 'eval_samples_per_second': 23.326, 'eval_steps_per_second': 1.111, 'epoch': 77.87}


  return fn(*args, **kwargs)
                                                   
 99%|█████████▊| 296/300 [1:10:40<00:55, 13.92s/it]

{'eval_loss': 1.7239590883255005, 'eval_runtime': 1.6439, 'eval_samples_per_second': 25.549, 'eval_steps_per_second': 1.217, 'epoch': 78.93}


  return fn(*args, **kwargs)
100%|██████████| 300/300 [1:11:31<00:00, 13.41s/it]

{'loss': 1.7099, 'learning_rate': 0.0, 'epoch': 80.0}


                                                   
100%|██████████| 300/300 [1:11:33<00:00, 13.41s/it]

{'eval_loss': 1.723951816558838, 'eval_runtime': 1.4001, 'eval_samples_per_second': 29.999, 'eval_steps_per_second': 1.429, 'epoch': 80.0}


100%|██████████| 300/300 [1:11:33<00:00, 14.31s/it]


{'train_runtime': 4293.6793, 'train_samples_per_second': 7.919, 'train_steps_per_second': 0.07, 'train_loss': 1.8776977920532227, 'epoch': 80.0}
Model saved to ./immigration_assistant_final
Testing model on examples...
Loading base model for comparison...


  return torch.load(checkpoint_file, map_location="cpu")


Loading fine-tuned model...


  adapters_weights = torch.load(



Generating responses from base and fine-tuned models...


--- Example 1 ---
Question: I received a Notice of Intent to Deny (NOID) my case from the government. What can I do?
Reference Answer: Many times the government improperly concludes that a case is deniable. Our experienced attorneys have successfully resolved cases in which the government intends to deny the case. While results may vary depending upon fact patterns and a case cannot always be resolved, a consultation with an attorney may turn up another avenue of relief.
Base Model Output: If you receive an NOI, it means that your petition has been denied and will not be processed by USCIS for further action until we have reviewed the reasons why our decision was made in error or if we decide to reopen the matter with additional evidence. We may also send out additional notices related to pending cases at any time without prior notice. Please note that there is no guarantee that these notifications will provide us with more inf

In [8]:
import os
import json
import glob
import re
import pandas as pd
import numpy as np
import torch
from datasets import Dataset, DatasetDict
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    DataCollatorForLanguageModeling
)
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training, get_peft_model
from trl import SFTTrainer
import evaluate
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import math

# ==================== CONFIGURATION ====================
# Set paths
BASE_DIRECTORY = "/home/hailemicaelyimer/Desktop/immigration-assistant/frequently-asked-questions"
DATASET_PATH = "./immigration_qa_dataset_clean"
OUTPUT_DIR = "./immigration_assistant_model_final"
FINAL_MODEL_PATH = "./immigration_assistant_final"
RESULTS_CSV = "./model_comparison_results.csv"
METRICS_CSV = "./model_evaluation_metrics.csv"

# Open-access model that doesn't require authentication
MODEL_ID = "facebook/opt-1.3b"  # 1.3B parameters, open access

# Training parameters
EPOCHS = 100
BATCH_SIZE = 6
LEARNING_RATE = 2e-5
LORA_RANK = 32
LORA_ALPHA = 64

# ==================== HELPER FUNCTIONS ====================
def clean_text(text):
    """Clean text by removing question/answer prefixes and extra whitespace."""
    # Remove "Q." or "Q#." prefixes from questions
    text = re.sub(r'^Q\.?\s*\d*\.?\s*', '', text)
    # Remove "A." or "A#." prefixes from answers
    text = re.sub(r'^A\.?\s*\d*\.?\s*', '', text)
    # Remove extra whitespace
    text = re.sub(r'\s+', ' ', text).strip()
    return text

def load_and_clean_json_files(directory_path):
    """Load and clean all JSON files in the directory."""
    json_files = glob.glob(os.path.join(directory_path, "**/*.json"), recursive=True)
    
    all_data = []
    print(f"Found {len(json_files)} JSON files")
    
    for file_path in json_files:
        try:
            with open(file_path, 'r', encoding='utf-8') as file:
                data = json.load(file)
                
                for item in data:
                    question = item.get("question", "").strip()
                    answer = item.get("answer", "").strip()
                    
                    # Skip items with empty answers or questions
                    if not question or not answer:
                        continue
                    
                    # Clean the texts
                    question = clean_text(question)
                    answer = clean_text(answer)
                    
                    # Skip very short answers (likely not useful)
                    if len(answer) < 20:
                        continue
                    
                    all_data.append({
                        "Question": question,
                        "Answer": answer
                    })
        except Exception as e:
            print(f"Error processing file {file_path}: {e}")
    
    print(f"Loaded {len(all_data)} clean question-answer pairs")
    return all_data

def post_process_response(text):
    """Clean model outputs by removing repetitions and known artifacts."""
    # Remove irrelevant prefix text
    if "Question:" in text and "Answer:" in text:
        text = text.split("Answer:", 1)[1].strip()
    
    # Split by lines and remove duplicates while preserving order
    lines = text.split('\n')
    seen_texts = set()
    unique_lines = []
    
    for line in lines:
        line = line.strip()
        if not line:
            continue
            
        # Skip duplicate content
        if line in seen_texts:
            continue
            
        # Skip lines that are question-like
        if line.lower().startswith(("question:", "q:", "what is", "how do", "can i")):
            continue
            
        seen_texts.add(line)
        unique_lines.append(line)
    
    # Join unique lines
    processed_text = '\n'.join(unique_lines)
    
    # If we filtered too much, return the original without duplicates
    if len(processed_text) < 20 and len(text) > 20:
        lines = text.split('\n')
        seen_texts = set()
        unique_lines = []
        for line in lines:
            if line.strip() and line.strip() not in seen_texts:
                seen_texts.add(line.strip())
                unique_lines.append(line)
        processed_text = '\n'.join(unique_lines)
    
    return processed_text

# ==================== ENHANCED EVALUATION FUNCTIONS ====================

def calculate_perplexity(model, tokenizer, texts, max_length=512):
    """
    Calculate perplexity of texts using the given model.
    Lower perplexity indicates better predictive performance.
    """
    model.eval()
    total_perplexity = 0
    
    with torch.no_grad():
        for text in texts:
            # Tokenize input text
            inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_length)
            inputs = {k: v.to(model.device) for k, v in inputs.items()}
            
            # Forward pass
            outputs = model(**inputs, labels=inputs["input_ids"])
            loss = outputs.loss
            
            # Calculate perplexity
            perplexity = math.exp(loss.item())
            total_perplexity += perplexity
    
    # Return average perplexity across all texts
    return total_perplexity / len(texts)

def calculate_tfidf_cosine_similarity(predictions, references):
    """
    Calculate cosine similarity between predictions and references using TF-IDF.
    Higher values indicate more similar content.
    """
    # Initialize TF-IDF vectorizer
    vectorizer = TfidfVectorizer()
    
    # Combine all texts to fit the vectorizer
    all_texts = predictions + references
    vectorizer.fit(all_texts)
    
    # Transform texts to TF-IDF vectors
    pred_vectors = vectorizer.transform(predictions)
    ref_vectors = vectorizer.transform(references)
    
    # Calculate cosine similarities
    similarities = []
    for i in range(len(predictions)):
        pred_vector = pred_vectors[i:i+1]  # Get the i-th prediction vector
        ref_vector = ref_vectors[i:i+1]    # Get the i-th reference vector
        similarity = cosine_similarity(pred_vector, ref_vector)[0][0]
        similarities.append(similarity)
    
    # Return average similarity
    return sum(similarities) / len(similarities)

def calculate_bleu_score(predictions, references):
    """
    Calculate BLEU score for predictions against references.
    Higher scores indicate better overlap in n-grams.
    """
    bleu = evaluate.load('bleu')
    
    # Format references as list of lists (BLEU expects multiple references format)
    formatted_references = [[ref] for ref in references]
    
    # Calculate BLEU score
    results = bleu.compute(predictions=predictions, references=formatted_references)
    
    return results['bleu']

def calculate_meteor_score(predictions, references):
    """
    Calculate METEOR score for predictions against references.
    METEOR is a metric that considers synonyms and paraphrasing.
    """
    try:
        meteor = evaluate.load('meteor')
        results = meteor.compute(predictions=predictions, references=references)
        return results['meteor']
    except:
        print("METEOR metric could not be loaded or computed. Skipping.")
        return None

def calculate_word_error_rate(predictions, references):
    """
    Calculate Word Error Rate between predictions and references.
    Lower WER indicates fewer word-level differences.
    """
    try:
        wer = evaluate.load('wer')
        results = wer.compute(predictions=predictions, references=references)
        return results
    except:
        print("WER metric could not be loaded or computed. Skipping.")
        return None

def evaluate_model_comprehensive(model, tokenizer, questions, reference_answers, 
                                model_name="Model", max_length=512, temperature=0.7):
    """
    Comprehensive evaluation of model outputs using multiple metrics.
    """
    print(f"\nEvaluating {model_name}...")
    generated_outputs = []
    
    # Generate responses from model
    for question in questions:
        # Format prompt for the model
        prompt = f"### Instruction: You are an immigration assistant. Provide accurate information about this question: {question}\n\n### Response:"
        
        # Generate output
        input_ids = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=max_length).input_ids.to(model.device)
        outputs = model.generate(
            input_ids=input_ids, 
            max_length=len(input_ids[0]) + 250, 
            temperature=temperature,
            top_p=0.9,
            do_sample=True,
            repetition_penalty=1.3,
            no_repeat_ngram_size=3
        )
        
        # Decode and clean output
        text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        text = text.replace(prompt, "").strip()
        text = post_process_response(text)
        generated_outputs.append(text)
    
    # Calculate metrics
    metrics = {}
    
    # 1. ROUGE scores
    rouge = evaluate.load('rouge')
    rouge_results = rouge.compute(
        predictions=generated_outputs,
        references=reference_answers,
        use_stemmer=True
    )
    metrics.update({f"rouge_{k}": v for k, v in rouge_results.items()})
    
    # 2. TF-IDF Cosine Similarity
    cos_sim = calculate_tfidf_cosine_similarity(generated_outputs, reference_answers)
    metrics['tfidf_cosine_similarity'] = cos_sim
    
    # 3. BLEU Score
    bleu = calculate_bleu_score(generated_outputs, reference_answers)
    metrics['bleu'] = bleu
    
    # 4. METEOR Score (if available)
    meteor = calculate_meteor_score(generated_outputs, reference_answers)
    if meteor is not None:
        metrics['meteor'] = meteor
    
    # 5. Word Error Rate (if available)
    wer = calculate_word_error_rate(generated_outputs, reference_answers)
    if wer is not None:
        metrics['wer'] = wer
    
    # 6. Perplexity (on reference answers)
    try:
        perplexity = calculate_perplexity(model, tokenizer, reference_answers)
        metrics['perplexity'] = perplexity
    except Exception as e:
        print(f"Error calculating perplexity: {e}")
        metrics['perplexity'] = None
    
    return metrics, generated_outputs

# ==================== MAIN SCRIPT ====================

def main():
    # Check if dataset already exists
    if os.path.exists(DATASET_PATH):
        print(f"Dataset found at {DATASET_PATH}, loading...")
        dataset_dict = DatasetDict.load_from_disk(DATASET_PATH)
        
        # Convert to DataFrames for later use
        train_df = pd.DataFrame(dataset_dict['train'])
        val_df = pd.DataFrame(dataset_dict['validation'])
        test_df = pd.DataFrame(dataset_dict['test'])
        
        print(f"Train size: {len(train_df)}, Validation size: {len(val_df)}, Test size: {len(test_df)}")
    else:
        # Step 1: Load and prepare data
        print("Loading and cleaning data...")
        all_qa_data = load_and_clean_json_files(BASE_DIRECTORY)

        # Convert to DataFrame
        df = pd.DataFrame(all_qa_data)
        print(f"Dataset shape: {df.shape}")
        print("Sample data:")
        print(df.head(2))

        # Split into train, validation, and test sets (80%, 10%, 10%)
        train_size = int(0.8 * len(df))
        val_size = int(0.1 * len(df))

        train_df = df[:train_size]
        val_df = df[train_size:train_size+val_size]
        test_df = df[train_size+val_size:]

        print(f"Train size: {len(train_df)}, Validation size: {len(val_df)}, Test size: {len(test_df)}")

        # Convert to Hugging Face datasets
        train_dataset = Dataset.from_pandas(train_df)
        val_dataset = Dataset.from_pandas(val_df)
        test_dataset = Dataset.from_pandas(test_df)

        # Combine into a dataset dictionary
        dataset_dict = DatasetDict({
            'train': train_dataset,
            'validation': val_dataset,
            'test': test_dataset
        })

        # Save the clean dataset to disk
        os.makedirs(DATASET_PATH, exist_ok=True)
        dataset_dict.save_to_disk(DATASET_PATH)
        print(f"Dataset saved to {DATASET_PATH}")

    # Step 2: Load Model and Tokenizer
    # Define quantization config for 4-bit precision
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=False,
    )

    # Load tokenizer
    print(f"Loading tokenizer for {MODEL_ID}...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    # For evaluation only - checking if trained model exists
    if os.path.exists(FINAL_MODEL_PATH):
        print("Trained model found. Proceeding to evaluation...")
        
        # Load base model for comparison
        print("Loading base model for comparison...")
        base_model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            quantization_config=bnb_config,
            device_map={"": 0}
        )

        # Load fine-tuned model (PEFT)
        print("Loading fine-tuned model...")
        peft_model = PeftModel.from_pretrained(
            base_model,
            FINAL_MODEL_PATH,
            device_map={"": 0}
        )

        # Test on examples from test set
        num_test_examples = min(15, len(test_df))  # Use at most 15 examples to keep evaluation time reasonable
        test_questions = test_df['Question'][:num_test_examples].tolist()
        test_answers = test_df['Answer'][:num_test_examples].tolist()

        # Perform comprehensive evaluation
        base_metrics, base_outputs = evaluate_model_comprehensive(
            base_model, tokenizer, test_questions, test_answers, model_name="Base Model"
        )
        
        ft_metrics, ft_outputs = evaluate_model_comprehensive(
            peft_model, tokenizer, test_questions, test_answers, model_name="Fine-tuned Model"
        )

        # Save and display results
        print("\n--- Comprehensive Evaluation Results ---")
        print("\nBase Model Metrics:")
        for k, v in base_metrics.items():
            print(f"{k}: {v}")
        
        print("\nFine-tuned Model Metrics:")
        for k, v in ft_metrics.items():
            print(f"{k}: {v}")
        
        # Create results DataFrame
        metrics_df = pd.DataFrame({
            'Metric': list(base_metrics.keys()),
            'Base Model': list(base_metrics.values()),
            'Fine-tuned Model': list(ft_metrics.values()),
            'Difference': [ft_metrics[k] - base_metrics[k] if isinstance(base_metrics[k], (int, float)) and 
                          isinstance(ft_metrics[k], (int, float)) else None 
                          for k in base_metrics.keys()]
        })
        
        # Save to CSV
        metrics_df.to_csv(METRICS_CSV, index=False)
        print(f"\nSaved metrics results to {METRICS_CSV}")

        # Save examples for qualitative analysis
        examples_df = pd.DataFrame({
            "Question": test_questions,
            "Reference_Answer": test_answers,
            "Base_Model_Output": base_outputs,
            "Fine_Tuned_Output": ft_outputs
        })
        
        examples_df.to_csv(RESULTS_CSV, index=False)
        print(f"Saved comparison results to {RESULTS_CSV}")

        # Print a few examples for quick reference
        print("\n--- Example Outputs ---")
        for i in range(min(3, len(test_questions))):
            print(f"\n\n--- Example {i+1} ---")
            print(f"Question: {test_questions[i]}")
            print(f"Reference Answer: {test_answers[i]}")
            print(f"Base Model Output: {base_outputs[i]}")
            print(f"Fine-tuned Model Output: {ft_outputs[i]}")

        print("\nEvaluation complete!")
    else:
        print(f"Trained model not found at {FINAL_MODEL_PATH}. Please train the model first.")
        # You can add training code here if needed

if __name__ == "__main__":
    main()

Dataset found at ./immigration_qa_dataset_clean, loading...
Train size: 340, Validation size: 42, Test size: 43
Loading tokenizer for facebook/opt-1.3b...




Trained model found. Proceeding to evaluation...
Loading base model for comparison...


  return torch.load(checkpoint_file, map_location="cpu")


Loading fine-tuned model...


  adapters_weights = torch.load(



Evaluating Base Model...


Downloading builder script: 100%|██████████| 5.94k/5.94k [00:00<00:00, 7.86MB/s]
Downloading extra modules: 4.07kB [00:00, 4.35MB/s]                   
Downloading extra modules: 100%|██████████| 3.34k/3.34k [00:00<00:00, 9.66MB/s]
Downloading builder script: 100%|██████████| 6.81k/6.81k [00:00<00:00, 10.9MB/s]


METEOR metric could not be loaded or computed. Skipping.


Downloading builder script: 100%|██████████| 4.49k/4.49k [00:00<00:00, 10.0MB/s]


WER metric could not be loaded or computed. Skipping.

Evaluating Fine-tuned Model...
METEOR metric could not be loaded or computed. Skipping.
WER metric could not be loaded or computed. Skipping.

--- Comprehensive Evaluation Results ---

Base Model Metrics:
rouge_rouge1: 0.22484908760582661
rouge_rouge2: 0.022768048744514623
rouge_rougeL: 0.10347650150268695
rouge_rougeLsum: 0.10461042494637657
tfidf_cosine_similarity: 0.17600273056830965
bleu: 0.0077465861856109925
perplexity: 30.44202381033492

Fine-tuned Model Metrics:
rouge_rouge1: 0.24380942214897616
rouge_rouge2: 0.025924972632562833
rouge_rougeL: 0.1133994902769046
rouge_rougeLsum: 0.11344463427783627
tfidf_cosine_similarity: 0.21276319254850184
bleu: 0.007051490908977086
perplexity: 30.44202381033492

Saved metrics results to ./model_evaluation_metrics.csv
Saved comparison results to ./model_comparison_results.csv

--- Example Outputs ---


--- Example 1 ---
Question: I received a Notice of Intent to Deny (NOID) my case from 

In [5]:
!pip install scikit-learn

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [13]:
import os
import json
import glob
import re
import pandas as pd
import numpy as np
import torch
from datasets import Dataset, DatasetDict
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    DataCollatorForLanguageModeling
)
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training, get_peft_model
from trl import SFTTrainer
import evaluate
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import math

# ==================== CONFIGURATION ====================
# Set paths
BASE_DIRECTORY = "/home/hailemicaelyimer/Desktop/immigration-assistant/frequently-asked-questions"
DATASET_PATH = "./immigration_qa_dataset_clean"
OUTPUT_DIR = "./immigration_assistant_gemma_model"
FINAL_MODEL_PATH = "./immigration_assistant_gemma_final"
RESULTS_CSV = "./gemma_model_comparison_results.csv"
METRICS_CSV = "./gemma_model_evaluation_metrics.csv"

# Use OPT-1.3B as a fallback since it's already working for you
MODEL_ID = "facebook/opt-1.3b"  # Use your original model that worked

# Training parameters
EPOCHS = 30
BATCH_SIZE = 6  # Use your original batch size
LEARNING_RATE = 2e-5
LORA_RANK = 32
LORA_ALPHA = 64

# ==================== HELPER FUNCTIONS ====================
def clean_text(text):
    """Clean text by removing question/answer prefixes and extra whitespace."""
    # Remove "Q." or "Q#." prefixes from questions
    text = re.sub(r'^Q\.?\s*\d*\.?\s*', '', text)
    # Remove "A." or "A#." prefixes from answers
    text = re.sub(r'^A\.?\s*\d*\.?\s*', '', text)
    # Remove extra whitespace
    text = re.sub(r'\s+', ' ', text).strip()
    return text

def load_and_clean_json_files(directory_path):
    """Load and clean all JSON files in the directory."""
    json_files = glob.glob(os.path.join(directory_path, "**/*.json"), recursive=True)
    
    all_data = []
    print(f"Found {len(json_files)} JSON files")
    
    for file_path in json_files:
        try:
            with open(file_path, 'r', encoding='utf-8') as file:
                data = json.load(file)
                
                for item in data:
                    question = item.get("question", "").strip()
                    answer = item.get("answer", "").strip()
                    
                    # Skip items with empty answers or questions
                    if not question or not answer:
                        continue
                    
                    # Clean the texts
                    question = clean_text(question)
                    answer = clean_text(answer)
                    
                    # Skip very short answers (likely not useful)
                    if len(answer) < 20:
                        continue
                    
                    all_data.append({
                        "Question": question,
                        "Answer": answer
                    })
        except Exception as e:
            print(f"Error processing file {file_path}: {e}")
    
    print(f"Loaded {len(all_data)} clean question-answer pairs")
    return all_data

def post_process_response(text):
    """Clean model outputs by removing repetitions and known artifacts."""
    # Remove irrelevant prefix text
    if "Question:" in text and "Answer:" in text:
        text = text.split("Answer:", 1)[1].strip()
    
    # Split by lines and remove duplicates while preserving order
    lines = text.split('\n')
    seen_texts = set()
    unique_lines = []
    
    for line in lines:
        line = line.strip()
        if not line:
            continue
            
        # Skip duplicate content
        if line in seen_texts:
            continue
            
        # Skip lines that are question-like
        if line.lower().startswith(("question:", "q:", "what is", "how do", "can i")):
            continue
            
        seen_texts.add(line)
        unique_lines.append(line)
    
    # Join unique lines
    processed_text = '\n'.join(unique_lines)
    
    # If we filtered too much, return the original without duplicates
    if len(processed_text) < 20 and len(text) > 20:
        lines = text.split('\n')
        seen_texts = set()
        unique_lines = []
        for line in lines:
            if line.strip() and line.strip() not in seen_texts:
                seen_texts.add(line.strip())
                unique_lines.append(line)
        processed_text = '\n'.join(unique_lines)
    
    return processed_text

# ==================== ENHANCED EVALUATION FUNCTIONS ====================

def calculate_perplexity(model, tokenizer, texts, max_length=512):
    """
    Calculate perplexity of texts using the given model.
    Lower perplexity indicates better predictive performance.
    """
    model.eval()
    total_perplexity = 0
    
    with torch.no_grad():
        for text in texts:
            # Tokenize input text
            inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_length)
            inputs = {k: v.to(model.device) for k, v in inputs.items()}
            
            # Forward pass
            outputs = model(**inputs, labels=inputs["input_ids"])
            loss = outputs.loss
            
            # Calculate perplexity
            perplexity = math.exp(loss.item())
            total_perplexity += perplexity
    
    # Return average perplexity across all texts
    return total_perplexity / len(texts)

def calculate_tfidf_cosine_similarity(predictions, references):
    """
    Calculate cosine similarity between predictions and references using TF-IDF.
    Higher values indicate more similar content.
    """
    # Initialize TF-IDF vectorizer
    vectorizer = TfidfVectorizer()
    
    # Combine all texts to fit the vectorizer
    all_texts = predictions + references
    vectorizer.fit(all_texts)
    
    # Transform texts to TF-IDF vectors
    pred_vectors = vectorizer.transform(predictions)
    ref_vectors = vectorizer.transform(references)
    
    # Calculate cosine similarities
    similarities = []
    for i in range(len(predictions)):
        pred_vector = pred_vectors[i:i+1]  # Get the i-th prediction vector
        ref_vector = ref_vectors[i:i+1]    # Get the i-th reference vector
        similarity = cosine_similarity(pred_vector, ref_vector)[0][0]
        similarities.append(similarity)
    
    # Return average similarity
    return sum(similarities) / len(similarities)

def calculate_bleu_score(predictions, references):
    """
    Calculate BLEU score for predictions against references.
    Higher scores indicate better overlap in n-grams.
    """
    bleu = evaluate.load('bleu')
    
    # Format references as list of lists (BLEU expects multiple references format)
    formatted_references = [[ref] for ref in references]
    
    # Calculate BLEU score
    results = bleu.compute(predictions=predictions, references=formatted_references)
    
    return results['bleu']

def evaluate_model_comprehensive(model, tokenizer, questions, reference_answers, 
                                model_name="Model", max_length=512, temperature=0.7):
    """
    Comprehensive evaluation of model outputs using multiple metrics.
    """
    print(f"\nEvaluating {model_name}...")
    generated_outputs = []
    
    # Generate responses from model
    for question in questions:
        # Format prompt for the model - Using standard instruction format
        prompt = f"You are an immigration assistant. Provide accurate information about this question: {question}\n\nAnswer:"
        
        # Generate output
        input_ids = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=max_length).input_ids.to(model.device)
        outputs = model.generate(
            input_ids=input_ids, 
            max_length=len(input_ids[0]) + 250, 
            temperature=temperature,
            top_p=0.9,
            do_sample=True,
            repetition_penalty=1.3,
            no_repeat_ngram_size=3
        )
        
        # Decode and clean output
        text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        text = text.replace(prompt, "").strip()
        text = post_process_response(text)
        generated_outputs.append(text)
    
    # Calculate metrics
    metrics = {}
    
    # 1. ROUGE scores
    rouge = evaluate.load('rouge')
    rouge_results = rouge.compute(
        predictions=generated_outputs,
        references=reference_answers,
        use_stemmer=True
    )
    metrics.update({f"rouge_{k}": v for k, v in rouge_results.items()})
    
    # 2. TF-IDF Cosine Similarity
    cos_sim = calculate_tfidf_cosine_similarity(generated_outputs, reference_answers)
    metrics['tfidf_cosine_similarity'] = cos_sim
    
    # 3. BLEU Score
    bleu = calculate_bleu_score(generated_outputs, reference_answers)
    metrics['bleu'] = bleu
    
    # 4. Perplexity (on reference answers)
    try:
        perplexity = calculate_perplexity(model, tokenizer, reference_answers)
        metrics['perplexity'] = perplexity
    except Exception as e:
        print(f"Error calculating perplexity: {e}")
        metrics['perplexity'] = None
    
    return metrics, generated_outputs

# ==================== MAIN SCRIPT ====================

def main():
    # Check if dataset already exists
    if os.path.exists(DATASET_PATH):
        print(f"Dataset found at {DATASET_PATH}, loading...")
        dataset_dict = DatasetDict.load_from_disk(DATASET_PATH)
        
        # Convert to DataFrames for later use
        train_df = pd.DataFrame(dataset_dict['train'])
        val_df = pd.DataFrame(dataset_dict['validation'])
        test_df = pd.DataFrame(dataset_dict['test'])
        
        print(f"Train size: {len(train_df)}, Validation size: {len(val_df)}, Test size: {len(test_df)}")
    else:
        # Step 1: Load and prepare data
        print("Loading and cleaning data...")
        all_qa_data = load_and_clean_json_files(BASE_DIRECTORY)

        # Convert to DataFrame
        df = pd.DataFrame(all_qa_data)
        print(f"Dataset shape: {df.shape}")
        print("Sample data:")
        print(df.head(2))

        # Split into train, validation, and test sets (80%, 10%, 10%)
        train_size = int(0.8 * len(df))
        val_size = int(0.1 * len(df))

        train_df = df[:train_size]
        val_df = df[train_size:train_size+val_size]
        test_df = df[train_size+val_size:]

        print(f"Train size: {len(train_df)}, Validation size: {len(val_df)}, Test size: {len(test_df)}")

        # Convert to Hugging Face datasets
        train_dataset = Dataset.from_pandas(train_df)
        val_dataset = Dataset.from_pandas(val_df)
        test_dataset = Dataset.from_pandas(test_df)

        # Combine into a dataset dictionary
        dataset_dict = DatasetDict({
            'train': train_dataset,
            'validation': val_dataset,
            'test': test_dataset
        })

        # Save the clean dataset to disk
        os.makedirs(DATASET_PATH, exist_ok=True)
        dataset_dict.save_to_disk(DATASET_PATH)
        print(f"Dataset saved to {DATASET_PATH}")

    # Step 2: Load Model and Tokenizer
    # Define quantization config for 4-bit precision (with fallback to 8-bit)
    try:
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=False,
        )
    except:
        print("4-bit quantization not supported, falling back to 8-bit")
        bnb_config = BitsAndBytesConfig(
            load_in_8bit=True,
            llm_int8_threshold=6.0
        )

    # Load tokenizer
    print(f"Loading tokenizer for {MODEL_ID}...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    # Load model with quantization config
    print("Loading model...")
    try:
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID, 
            quantization_config=bnb_config, 
            use_cache=False,
            device_map={"": 0}  # Use GPU 0
        )
    except Exception as e:
        print(f"Error loading model with quantization: {e}")
        print("Trying without quantization...")
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            use_cache=False,
            device_map={"": 0}
        )

    # Step 3: Define preprocessing function for clean instruction format
    def preprocess_function(examples):
        # Use a simple instruction format
        formatted_prompts = [
            f"You are an immigration assistant. Provide accurate information about this question: {q}\n\nAnswer:" 
            for q in examples["Question"]
        ]
        
        return {
            "input_ids": tokenizer(
                formatted_prompts,
                truncation=True, 
                max_length=512, 
                padding="max_length"
            )["input_ids"],
            "labels": tokenizer(
                examples["Answer"], 
                truncation=True, 
                max_length=512, 
                padding="max_length"
            )["input_ids"],
            "inputs_text": [f"{prompt} {answer}" for prompt, answer in zip(formatted_prompts, examples["Answer"])],
        }

    # Apply preprocessing to datasets
    print("Preprocessing datasets...")
    processed_train_dataset = dataset_dict['train'].map(preprocess_function, batched=True)
    processed_val_dataset = dataset_dict['validation'].map(preprocess_function, batched=True)
    processed_test_dataset = dataset_dict['test'].map(preprocess_function, batched=True)

    print(f"Processed train dataset size: {len(processed_train_dataset)}")
    print(f"Processed validation dataset size: {len(processed_val_dataset)}")
    print(f"Processed test dataset size: {len(processed_test_dataset)}")

    # Step 4: Configure LoRA for efficient fine-tuning
    # Target specific modules for OPT architecture
    peft_config = LoraConfig(
        lora_alpha=LORA_ALPHA,
        lora_dropout=0.02,
        r=LORA_RANK,
        bias="none",
        task_type="CAUSAL_LM",
        # Target modules for OPT model
        target_modules=["q_proj", "k_proj", "v_proj", "out_proj"]
    )

    # Prepare model for kbit training
    print("Preparing model for training...")
    model = prepare_model_for_kbit_training(model)
    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()

    # Step 5: Define training arguments
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    training_arguments = TrainingArguments(
        output_dir=OUTPUT_DIR,
        num_train_epochs=EPOCHS,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        gradient_accumulation_steps=4,
        gradient_checkpointing=True,
        optim="paged_adamw_32bit",
        logging_steps=10,
        save_strategy="epoch",
        evaluation_strategy="epoch",
        learning_rate=LEARNING_RATE,
        weight_decay=0.01,
        max_grad_norm=0.3,
        warmup_ratio=0.05,
        group_by_length=True,
        lr_scheduler_type="cosine",
        fp16=False,
        bf16=False,
        report_to="none",
        # Add the following to prevent repetition during training
        remove_unused_columns=False,
        label_names=["labels"],
    )

    # Data collator for language model training
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False
    )

    # Step 6: Create and train the model
    trainer = SFTTrainer(
        model=model,
        train_dataset=processed_train_dataset,
        eval_dataset=processed_val_dataset,
        dataset_text_field="inputs_text",
        max_seq_length=512,
        tokenizer=tokenizer,
        args=training_arguments,
        data_collator=data_collator,
        packing=False,
    )

    # Start training
    print("Starting training...")
    trainer.train()

    # Step 7: Save the trained model
    os.makedirs(FINAL_MODEL_PATH, exist_ok=True)
    trainer.model.save_pretrained(FINAL_MODEL_PATH)
    tokenizer.save_pretrained(FINAL_MODEL_PATH)
    print(f"Model saved to {FINAL_MODEL_PATH}")

    # Step 8: Evaluate the model
    print("\nPreparing for comprehensive evaluation...")
    
    # Reload base model for comparison
    print("Loading base model for comparison...")
    base_model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        quantization_config=bnb_config,
        device_map={"": 0}
    )

    # Load fine-tuned model (PEFT)
    print("Loading fine-tuned model...")
    peft_model = PeftModel.from_pretrained(
        base_model,
        FINAL_MODEL_PATH,
        device_map={"": 0}
    )

    # Test on examples from test set
    num_test_examples = min(15, len(test_df))  # Use at most 15 examples
    test_questions = test_df['Question'][:num_test_examples].tolist()
    test_answers = test_df['Answer'][:num_test_examples].tolist()

    # Perform comprehensive evaluation
    base_metrics, base_outputs = evaluate_model_comprehensive(
        base_model, tokenizer, test_questions, test_answers, model_name="Base Model"
    )
    
    ft_metrics, ft_outputs = evaluate_model_comprehensive(
        peft_model, tokenizer, test_questions, test_answers, model_name="Fine-tuned Model"
    )

    # Save and display results
    print("\n--- Comprehensive Evaluation Results ---")
    print("\nBase Model Metrics:")
    for k, v in base_metrics.items():
        print(f"{k}: {v}")
    
    print("\nFine-tuned Model Metrics:")
    for k, v in ft_metrics.items():
        print(f"{k}: {v}")
    
    # Create results DataFrame
    metrics_df = pd.DataFrame({
        'Metric': list(base_metrics.keys()),
        'Base Model': list(base_metrics.values()),
        'Fine-tuned Model': list(ft_metrics.values()),
        'Difference': [ft_metrics[k] - base_metrics[k] if isinstance(base_metrics[k], (int, float)) and 
                      isinstance(ft_metrics[k], (int, float)) else None 
                      for k in base_metrics.keys()]
    })
    
    # Save to CSV
    metrics_df.to_csv(METRICS_CSV, index=False)
    print(f"\nSaved metrics results to {METRICS_CSV}")

    # Save examples for qualitative analysis
    examples_df = pd.DataFrame({
        "Question": test_questions,
        "Reference_Answer": test_answers,
        "Base_Model_Output": base_outputs,
        "Fine_Tuned_Output": ft_outputs
    })
    
    examples_df.to_csv(RESULTS_CSV, index=False)
    print(f"Saved comparison results to {RESULTS_CSV}")

    # Print a few examples for quick reference
    print("\n--- Example Outputs ---")
    for i in range(min(3, len(test_questions))):
        print(f"\n\n--- Example {i+1} ---")
        print(f"Question: {test_questions[i]}")
        print(f"Reference Answer: {test_answers[i]}")
        print(f"Base Model Output: {base_outputs[i]}")
        print(f"Fine-tuned Model Output: {ft_outputs[i]}")

    print("\nTraining and evaluation complete!")

    # Create a simple inference function to test the model interactively
    def query_model(question, model=peft_model):
        prompt = f"You are an immigration assistant. Provide accurate information about this question: {question}\n\nAnswer:"
        input_ids = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).input_ids.to("cuda")
        
        outputs = model.generate(
            input_ids=input_ids, 
            max_length=len(input_ids[0]) + 300, 
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            repetition_penalty=1.3,
            no_repeat_ngram_size=3
        )
        
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        response = response.replace(prompt, "").strip()
        response = post_process_response(response)
        return response

    print("\nYou can now use the query_model() function to test your model interactively.")
    print("Example: response = query_model('What is the processing time for a green card application?')")

if __name__ == "__main__":
    main()

Dataset found at ./immigration_qa_dataset_clean, loading...
Train size: 340, Validation size: 42, Test size: 43
Loading tokenizer for facebook/opt-1.3b...
Loading model...
Preprocessing datasets...


Map: 100%|██████████| 340/340 [00:00<00:00, 1539.55 examples/s]
Map: 100%|██████████| 42/42 [00:00<00:00, 1280.10 examples/s]
Map: 100%|██████████| 43/43 [00:00<00:00, 613.27 examples/s]


Processed train dataset size: 340
Processed validation dataset size: 42
Processed test dataset size: 43
Preparing model for training...
trainable params: 12,582,912 || all params: 724,361,216 || trainable%: 1.7371045994820353


Map: 100%|██████████| 340/340 [00:00<00:00, 2065.20 examples/s]
Map: 100%|██████████| 42/42 [00:00<00:00, 2798.38 examples/s]


Starting training...


  0%|          | 0/90 [00:00<?, ?it/s]You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
  return fn(*args, **kwargs)
  3%|▎         | 3/90 [00:40<19:29, 13.44s/it]
  3%|▎         | 3/90 [00:51<19:29, 13.44s/it]

{'eval_loss': 2.643470525741577, 'eval_runtime': 1.5252, 'eval_samples_per_second': 27.538, 'eval_steps_per_second': 1.311, 'epoch': 0.8}


  return fn(*args, **kwargs)
  8%|▊         | 7/90 [01:39<19:34, 14.15s/it]
  8%|▊         | 7/90 [01:44<19:34, 14.15s/it]

{'eval_loss': 2.532104015350342, 'eval_runtime': 1.6905, 'eval_samples_per_second': 24.845, 'eval_steps_per_second': 1.183, 'epoch': 1.87}


  return fn(*args, **kwargs)
 11%|█         | 10/90 [02:21<18:27, 13.84s/it]

{'loss': 2.6269, 'learning_rate': 1.982973099683902e-05, 'epoch': 2.67}


 12%|█▏        | 11/90 [02:34<17:47, 13.52s/it]
 12%|█▏        | 11/90 [02:37<17:47, 13.52s/it]

{'eval_loss': 2.4275171756744385, 'eval_runtime': 1.7721, 'eval_samples_per_second': 23.701, 'eval_steps_per_second': 1.129, 'epoch': 2.93}


  return fn(*args, **kwargs)
 17%|█▋        | 15/90 [03:28<16:34, 13.26s/it]
 17%|█▋        | 15/90 [03:30<16:34, 13.26s/it]

{'eval_loss': 2.342378616333008, 'eval_runtime': 1.5153, 'eval_samples_per_second': 27.717, 'eval_steps_per_second': 1.32, 'epoch': 4.0}


  return fn(*args, **kwargs)
 20%|██        | 18/90 [04:11<16:34, 13.81s/it]
 20%|██        | 18/90 [04:24<16:34, 13.81s/it]

{'eval_loss': 2.28171706199646, 'eval_runtime': 1.7841, 'eval_samples_per_second': 23.541, 'eval_steps_per_second': 1.121, 'epoch': 4.8}


  return fn(*args, **kwargs)
 22%|██▏       | 20/90 [04:45<17:30, 15.01s/it]

{'loss': 2.4507, 'learning_rate': 1.8502171357296144e-05, 'epoch': 5.33}


 24%|██▍       | 22/90 [05:13<16:21, 14.44s/it]
 24%|██▍       | 22/90 [05:18<16:21, 14.44s/it]

{'eval_loss': 2.2095980644226074, 'eval_runtime': 1.7918, 'eval_samples_per_second': 23.441, 'eval_steps_per_second': 1.116, 'epoch': 5.87}


  return fn(*args, **kwargs)
 29%|██▉       | 26/90 [06:08<14:41, 13.78s/it]
 29%|██▉       | 26/90 [06:11<14:41, 13.78s/it]

{'eval_loss': 2.135554313659668, 'eval_runtime': 1.571, 'eval_samples_per_second': 26.734, 'eval_steps_per_second': 1.273, 'epoch': 6.93}


  return fn(*args, **kwargs)
 33%|███▎      | 30/90 [07:02<13:21, 13.35s/it]

{'loss': 2.3074, 'learning_rate': 1.6026346363792565e-05, 'epoch': 8.0}



 33%|███▎      | 30/90 [07:04<13:21, 13.35s/it]

{'eval_loss': 2.05501651763916, 'eval_runtime': 1.5221, 'eval_samples_per_second': 27.593, 'eval_steps_per_second': 1.314, 'epoch': 8.0}


  return fn(*args, **kwargs)
 37%|███▋      | 33/90 [07:45<13:11, 13.89s/it]
 37%|███▋      | 33/90 [07:57<13:11, 13.89s/it]

{'eval_loss': 1.9921693801879883, 'eval_runtime': 1.7772, 'eval_samples_per_second': 23.633, 'eval_steps_per_second': 1.125, 'epoch': 8.8}


  return fn(*args, **kwargs)
 41%|████      | 37/90 [08:45<12:39, 14.33s/it]
 41%|████      | 37/90 [08:50<12:39, 14.33s/it]

{'eval_loss': 1.9160311222076416, 'eval_runtime': 1.5673, 'eval_samples_per_second': 26.798, 'eval_steps_per_second': 1.276, 'epoch': 9.87}


  return fn(*args, **kwargs)
 44%|████▍     | 40/90 [09:27<11:42, 14.05s/it]

{'loss': 2.1746, 'learning_rate': 1.2736629900720832e-05, 'epoch': 10.67}


 46%|████▌     | 41/90 [09:40<11:15, 13.78s/it]
 46%|████▌     | 41/90 [09:43<11:15, 13.78s/it]

{'eval_loss': 1.864976406097412, 'eval_runtime': 1.618, 'eval_samples_per_second': 25.958, 'eval_steps_per_second': 1.236, 'epoch': 10.93}


  return fn(*args, **kwargs)
 50%|█████     | 45/90 [10:35<10:02, 13.38s/it]
 50%|█████     | 45/90 [10:37<10:02, 13.38s/it]

{'eval_loss': 1.8212977647781372, 'eval_runtime': 1.7587, 'eval_samples_per_second': 23.882, 'eval_steps_per_second': 1.137, 'epoch': 12.0}


  return fn(*args, **kwargs)
 53%|█████▎    | 48/90 [11:18<09:44, 13.90s/it]
 53%|█████▎    | 48/90 [11:30<09:44, 13.90s/it]

{'eval_loss': 1.814617395401001, 'eval_runtime': 1.6493, 'eval_samples_per_second': 25.465, 'eval_steps_per_second': 1.213, 'epoch': 12.8}


  return fn(*args, **kwargs)
 56%|█████▌    | 50/90 [11:52<10:01, 15.05s/it]

{'loss': 2.0727, 'learning_rate': 9.07731640536698e-06, 'epoch': 13.33}


 58%|█████▊    | 52/90 [12:18<08:52, 14.01s/it]
 58%|█████▊    | 52/90 [12:23<08:52, 14.01s/it]

{'eval_loss': 1.8092262744903564, 'eval_runtime': 1.5285, 'eval_samples_per_second': 27.477, 'eval_steps_per_second': 1.308, 'epoch': 13.87}


  return fn(*args, **kwargs)
 62%|██████▏   | 56/90 [13:12<07:40, 13.53s/it]
 62%|██████▏   | 56/90 [13:16<07:40, 13.53s/it]

{'eval_loss': 1.8058291673660278, 'eval_runtime': 2.0522, 'eval_samples_per_second': 20.466, 'eval_steps_per_second': 0.975, 'epoch': 14.93}


  return fn(*args, **kwargs)
 67%|██████▋   | 60/90 [14:07<06:39, 13.32s/it]

{'loss': 2.024, 'learning_rate': 5.542616442234618e-06, 'epoch': 16.0}



 67%|██████▋   | 60/90 [14:09<06:39, 13.32s/it]

{'eval_loss': 1.8030222654342651, 'eval_runtime': 1.4225, 'eval_samples_per_second': 29.526, 'eval_steps_per_second': 1.406, 'epoch': 16.0}


  return fn(*args, **kwargs)
 70%|███████   | 63/90 [14:49<06:09, 13.68s/it]
 70%|███████   | 63/90 [15:02<06:09, 13.68s/it]

{'eval_loss': 1.8012328147888184, 'eval_runtime': 1.8131, 'eval_samples_per_second': 23.165, 'eval_steps_per_second': 1.103, 'epoch': 16.8}


  return fn(*args, **kwargs)
 74%|███████▍  | 67/90 [15:51<05:30, 14.39s/it]
 74%|███████▍  | 67/90 [15:56<05:30, 14.39s/it]

{'eval_loss': 1.7986561059951782, 'eval_runtime': 1.9015, 'eval_samples_per_second': 22.088, 'eval_steps_per_second': 1.052, 'epoch': 17.87}


  return fn(*args, **kwargs)
 78%|███████▊  | 70/90 [16:34<04:45, 14.26s/it]

{'loss': 2.0244, 'learning_rate': 2.6099108277934105e-06, 'epoch': 18.67}


 79%|███████▉  | 71/90 [16:46<04:22, 13.79s/it]
 79%|███████▉  | 71/90 [16:49<04:22, 13.79s/it]

{'eval_loss': 1.7970116138458252, 'eval_runtime': 1.8109, 'eval_samples_per_second': 23.192, 'eval_steps_per_second': 1.104, 'epoch': 18.93}


  return fn(*args, **kwargs)
 83%|████████▎ | 75/90 [17:41<03:22, 13.47s/it]
 83%|████████▎ | 75/90 [17:43<03:22, 13.47s/it]

{'eval_loss': 1.7960962057113647, 'eval_runtime': 1.6475, 'eval_samples_per_second': 25.493, 'eval_steps_per_second': 1.214, 'epoch': 20.0}


  return fn(*args, **kwargs)
 87%|████████▋ | 78/90 [18:25<02:46, 13.89s/it]
 87%|████████▋ | 78/90 [18:37<02:46, 13.89s/it]

{'eval_loss': 1.7956740856170654, 'eval_runtime': 1.8622, 'eval_samples_per_second': 22.555, 'eval_steps_per_second': 1.074, 'epoch': 20.8}


  return fn(*args, **kwargs)
 89%|████████▉ | 80/90 [18:57<02:28, 14.81s/it]

{'loss': 1.9872, 'learning_rate': 6.752777059564431e-07, 'epoch': 21.33}


 91%|█████████ | 82/90 [19:25<01:54, 14.29s/it]
 91%|█████████ | 82/90 [19:30<01:54, 14.29s/it]

{'eval_loss': 1.7953498363494873, 'eval_runtime': 1.6008, 'eval_samples_per_second': 26.237, 'eval_steps_per_second': 1.249, 'epoch': 21.87}


  return fn(*args, **kwargs)
 96%|█████████▌| 86/90 [20:20<00:55, 13.81s/it]
 96%|█████████▌| 86/90 [20:23<00:55, 13.81s/it]

{'eval_loss': 1.7952032089233398, 'eval_runtime': 1.6013, 'eval_samples_per_second': 26.228, 'eval_steps_per_second': 1.249, 'epoch': 22.93}


  return fn(*args, **kwargs)
100%|██████████| 90/90 [21:15<00:00, 13.14s/it]

{'loss': 1.996, 'learning_rate': 0.0, 'epoch': 24.0}



100%|██████████| 90/90 [21:16<00:00, 13.14s/it]

{'eval_loss': 1.7951714992523193, 'eval_runtime': 1.6378, 'eval_samples_per_second': 25.644, 'eval_steps_per_second': 1.221, 'epoch': 24.0}


100%|██████████| 90/90 [21:16<00:00, 14.19s/it]


{'train_runtime': 1276.9845, 'train_samples_per_second': 7.988, 'train_steps_per_second': 0.07, 'train_loss': 2.1848822911580403, 'epoch': 24.0}
Model saved to ./immigration_assistant_gemma_final

Preparing for comprehensive evaluation...
Loading base model for comparison...


  return torch.load(checkpoint_file, map_location="cpu")


Loading fine-tuned model...


  adapters_weights = torch.load(



Evaluating Base Model...

Evaluating Fine-tuned Model...

--- Comprehensive Evaluation Results ---

Base Model Metrics:
rouge_rouge1: 0.25512243234592313
rouge_rouge2: 0.03536031362850581
rouge_rougeL: 0.11965858769181872
rouge_rougeLsum: 0.12181364238077827
tfidf_cosine_similarity: 0.2294454273578281
bleu: 0.012305708975260607
perplexity: 24.340541513230185

Fine-tuned Model Metrics:
rouge_rouge1: 0.24615175872918588
rouge_rouge2: 0.03461837850739836
rouge_rougeL: 0.11328537262858526
rouge_rougeLsum: 0.11805270888834501
tfidf_cosine_similarity: 0.2007021213417432
bleu: 0.012487286713355066
perplexity: 24.340541513230185

Saved metrics results to ./gemma_model_evaluation_metrics.csv
Saved comparison results to ./gemma_model_comparison_results.csv

--- Example Outputs ---


--- Example 1 ---
Question: I received a Notice of Intent to Deny (NOID) my case from the government. What can I do?
Reference Answer: Many times the government improperly concludes that a case is deniable. Our expe