# Section 4: Fine-Tuning using Saul LM 7B
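This section fine-tunes the Saul LM 7B model (`Equall/Saul-7B-Instruct-v1`) on the preprocessed case dataset.
The 4-bit quantized base model is adapted with LoRA so that it generates the case result from the case inputs; the fine-tuned model is then used to generate analyses for the held-out test cases.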

In [None]:
# Import necessary libraries and modules
from datasets import Dataset, DatasetDict
import torch
import json
from tqdm import tqdm
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import accelerate
import matplotlib.pyplot as plt
import pandas as pd
from datetime import datetime
import os
from peft import LoraConfig, get_peft_model
from transformers import Trainer, TrainingArguments, DataCollatorForSeq2Seq
import evaluate
import numpy as np

# Function to load JSON data line by line
def load_json_lines(filename):
    """Load a JSON Lines file into a list of decoded objects.

    Args:
        filename: Path to a UTF-8 text file holding one JSON document per line.

    Returns:
        A list with one decoded object per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON. The
            message is prefixed with the file name and 1-based line number.
    """
    data = []
    with open(filename, 'r', encoding='utf-8') as f:
        for line_number, line in enumerate(f, start=1):
            line = line.strip()
            # Blank lines (e.g. a trailing newline at the end of the file)
            # are not JSON documents; skip them instead of crashing.
            if not line:
                continue
            try:
                data.append(json.loads(line))
            except json.JSONDecodeError as err:
                # Same exception type as before, with the location added.
                raise json.JSONDecodeError(
                    f"{filename}:{line_number}: {err.msg}", err.doc, err.pos
                ) from err
    return data

# Read every case record from the JSON Lines export
f_cases_data = load_json_lines('./data/nia_cases_to_process.json')


def _judgment_date(case):
    """Parse a case's 'JudgmentDate' (day/month/year) into a datetime."""
    return datetime.strptime(case['JudgmentDate'], '%d/%m/%Y')


# Order the cases chronologically (oldest judgment first); the input list is left untouched
sorted_cases = list(f_cases_data)
sorted_cases.sort(key=_judgment_date)

# Read the no-context prompt template and drop surrounding whitespace
with open("./prompts/prompt_with_nocontext.txt", "r") as template_file:
    raw_template = template_file.read()
prompt_with_nocontext = raw_template.strip()

# Function to format the prompt using the no-context template
def format_prompt_nocontext(doc):
    """Fill the no-context prompt template with one case's inputs.

    ``doc['Case_Inputs']`` (or ``''`` when the key is missing) is serialized
    with ``json.dumps(..., indent=1)`` and the first two and last two
    characters of that text are dropped. For a non-empty dict these are the
    opening ``{`` + newline and the newline + closing ``}``.

    Args:
        doc: Mapping that may contain a 'Case_Inputs' entry.

    Returns:
        The template text with its ``{Case_Inputs}`` field substituted.
    """
    serialized_inputs = json.dumps(doc.get('Case_Inputs', ''), indent=1)
    inner_text = serialized_inputs[2:-2]
    return prompt_with_nocontext.format(Case_Inputs=inner_text)

# Model and Tokenizer Setup
# Prefer a token supplied through the HF_TOKEN environment variable so the secret
# does not have to be written into the notebook; otherwise fall back to the placeholder.
access_token = os.environ.get("HF_TOKEN") or "Enter your hugging face token"
model_id = "Equall/Saul-7B-Instruct-v1"

# Configure the model for 4-bit quantization for efficient memory usage
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,  # 4-bit quantization for better memory efficiency
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

# Load the model with the specified configuration
hf_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    use_cache=False,  # the KV cache is disabled for training
    token=access_token
)

# Enable gradient checkpointing for memory optimization
hf_model.gradient_checkpointing_enable()
torch.cuda.empty_cache()

# Load the tokenizer associated with the model.
# Use the same credentials as the model download so both requests are authorized alike.
hf_tokenizer = AutoTokenizer.from_pretrained(model_id, token=access_token)

# Set the EOS token as the padding token
hf_tokenizer.pad_token = hf_tokenizer.eos_token
hf_tokenizer.pad_token_id = hf_tokenizer.eos_token_id

# Chronological split: the first 80% of the date-sorted cases are used for training
train_end = int(0.8 * len(sorted_cases))
train_data = sorted_cases[:train_end]
temp_data = sorted_cases[train_end:]

# The remaining cases are halved: earlier half for validation, later half for test
val_end = int(0.5 * len(temp_data))
val_data = temp_data[:val_end]
test_data = temp_data[val_end:]

# Report the size of each split
print(f"Training data: {len(train_data)} cases")
print(f"Validation data: {len(val_data)} cases")
print(f"Test data: {len(test_data)} cases")

# Wrap each split in a Dataset and collect them in a DatasetDict
split_rows = (
    ('train', train_data),
    ('validation', val_data),
    ('test', test_data),
)
ft_dataset = DatasetDict(
    {split_name: Dataset.from_list(rows) for split_name, rows in split_rows}
)

def _parse_judgment_dates(cases):
    """Return the 'JudgmentDate' of each case parsed as day/month/year."""
    return [datetime.strptime(case['JudgmentDate'], '%d/%m/%Y') for case in cases]


# Extract judgment dates for visualization
train_dates = _parse_judgment_dates(train_data)
val_dates = _parse_judgment_dates(val_data)
test_dates = _parse_judgment_dates(test_data)

# Plot the distribution of cases over time across training, validation, and test sets
plt.figure(figsize=(6, 4))

plt.hist(train_dates, bins=100, alpha=0.6, label=f'Train ({len(train_data)})', color='skyblue')
plt.hist(val_dates, bins=10, alpha=0.7, label=f'Validation ({len(val_data)})', color='orange')
# The test split gets its own color; it previously reused the validation color,
# which made the two series indistinguishable in the plot and the legend.
plt.hist(test_dates, bins=10, alpha=0.7, label=f'Test ({len(test_data)})', color='green')

plt.xlabel('Judgment Date')
plt.ylabel('Number of Cases')
plt.title('Distribution of Cases on Time Axis')

plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()

In [None]:
# Configure the CUDA allocator through its environment variable
# NOTE(review): the base model was already loaded with device_map="auto" in the
# previous cell; confirm this setting still takes effect when assigned this late.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'

# Pick the GPU when one is available, otherwise the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Keep the model configuration's padding token in sync with the tokenizer
hf_model.config.pad_token_id = hf_tokenizer.pad_token_id

# LoRA (Low-Rank Adaptation) settings: rank-8 adapters on the attention projections
lora_settings = {
    "r": 8,
    "lora_alpha": 16,
    "lora_dropout": 0.05,
    "bias": "none",
    "task_type": "CAUSAL_LM",
    "target_modules": ["k_proj", "q_proj", "v_proj", "o_proj"],
}
peft_config = LoraConfig(**lora_settings)

# Wrap the base model with the LoRA adapters and switch it to training mode
peft_model = get_peft_model(hf_model, peft_config)
peft_model.train()

torch.cuda.empty_cache()

# Label value that marks positions excluded from the loss and from the metrics
IGNORE_INDEX = -100


def _encode_case(example, max_length):
    """Tokenize one case as ``prompt + answer + EOS`` for causal-LM fine-tuning.

    Prompt positions are set to IGNORE_INDEX in the labels so that only the
    answer tokens are supervised. The concatenated sequence is cut on the right
    at ``max_length`` tokens.
    """
    prompt_ids = hf_tokenizer(format_prompt_nocontext(example))["input_ids"]
    # add_special_tokens=False so no extra special tokens are inserted between
    # the prompt and the answer.
    answer_ids = hf_tokenizer(example['Case_Result'], add_special_tokens=False)["input_ids"]
    # Explicit EOS so the model is trained on where the analysis ends.
    answer_ids = answer_ids + [hf_tokenizer.eos_token_id]

    input_ids = (prompt_ids + answer_ids)[:max_length]
    labels = ([IGNORE_INDEX] * len(prompt_ids) + answer_ids)[:max_length]
    return {
        "input_ids": input_ids,
        "attention_mask": [1] * len(input_ids),
        "labels": labels,
    }


# Function to tokenize the dataset for training
def tokenize_function(case, max_length=4096):
    """Build causal-LM training features from one case or from a batch of cases.

    The previous implementation passed the whole batch to the prompt formatter
    (merging all cases into one prompt) and then overwrote the prompt tokens
    with the tokenized answer, so the model never saw a prompt/answer pair.
    Each case is now encoded on its own as prompt + answer + EOS.

    Args:
        case: Either a single example (``case['Case_Result']`` is a string) or
            a batch as passed by ``Dataset.map(..., batched=True)`` (every
            value is a list with one entry per example).
        max_length: Maximum number of tokens kept per example.

    Returns:
        A dict with 'input_ids', 'attention_mask' and 'labels'. For a single
        example the values are flat lists; for a batch they are lists with one
        entry per kept example. In batch mode, examples whose prompt alone
        fills ``max_length`` are dropped because they would contain no
        supervised token.
    """
    if isinstance(case['Case_Result'], str):
        # Single example: nothing can be dropped here, return it as encoded.
        return _encode_case(case, max_length)

    columns = list(case.keys())
    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    # Turn the dict of columns back into one dict per case.
    for row_values in zip(*(case[name] for name in columns)):
        encoded = _encode_case(dict(zip(columns, row_values)), max_length)
        if all(label == IGNORE_INDEX for label in encoded["labels"]):
            continue
        for key, value in encoded.items():
            batch[key].append(value)
    return batch

# Tokenize the entire dataset using the defined tokenize_function
tokenized_datasets = ft_dataset.map(tokenize_function, batched=True, remove_columns=ft_dataset['train'].column_names)

torch.cuda.empty_cache()

# Data collator for dynamic padding during training; labels are padded with IGNORE_INDEX
data_collator = DataCollatorForSeq2Seq(
    tokenizer=hf_tokenizer,
    model=hf_model,
    padding=True,
    label_pad_token_id=IGNORE_INDEX,
)

# Load evaluation metrics
bertscore_metric = evaluate.load("bertscore")
rouge_metric = evaluate.load("rouge")
meteor_metric = evaluate.load("meteor")

# Function to compute evaluation metrics
def compute_metrics(eval_preds):
    """Score the model's next-token predictions against the reference answers.

    Args:
        eval_preds: ``(predictions, labels)`` pair. ``predictions`` may be raw
            logits of shape (examples, tokens, vocabulary) or already-selected
            token ids of shape (examples, tokens); ``labels`` holds token ids
            with IGNORE_INDEX at unsupervised or padded positions.

    Returns:
        Dict with 'bertscore_f1', 'rouge_l_f1', 'meteor' and their weighted
        combination 'weighted_score'.
    """
    logits, labels = eval_preds
    if isinstance(logits, tuple):
        logits = logits[0]
    logits = np.asarray(logits)
    labels = np.asarray(labels)
    preds = np.argmax(logits, axis=-1) if logits.ndim == 3 else logits

    # The output at position i is the prediction for token i + 1, so shift the
    # predictions left by one to line them up with the labels.
    preds = preds[:, :-1]
    targets = labels[:, 1:]

    # IGNORE_INDEX is not a valid token id and cannot be decoded; replace every
    # unsupervised position with the pad token, which decoding then skips.
    supervised = targets != IGNORE_INDEX
    pad_id = hf_tokenizer.pad_token_id
    preds = np.where(supervised, preds, pad_id)
    targets = np.where(supervised, targets, pad_id)

    decoded_preds = hf_tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = hf_tokenizer.batch_decode(targets, skip_special_tokens=True)

    # Only supervised (answer) positions survive the masking above, so there is
    # no prompt section left to cut off; just trim surrounding whitespace.
    cleaned_preds = [pred.strip() for pred in decoded_preds]
    cleaned_labels = [label.strip() for label in decoded_labels]

    # Compute different evaluation metrics
    bertscore_results = bertscore_metric.compute(predictions=cleaned_preds, references=cleaned_labels, lang="en")
    rouge_results = rouge_metric.compute(predictions=cleaned_preds, references=cleaned_labels)
    meteor_results = meteor_metric.compute(predictions=cleaned_preds, references=cleaned_labels)

    bertscore_f1 = float(np.mean(bertscore_results['f1']))
    rouge_l_f1 = float(rouge_results['rougeL'])
    meteor_score = float(meteor_results['meteor'])

    # Weighted score for overall evaluation
    weights = {'bertscore': 0.1, 'rouge': 0.6, 'meteor': 0.3}
    weighted_score = (weights['bertscore'] * bertscore_f1 +
                      weights['rouge'] * rouge_l_f1 +
                      weights['meteor'] * meteor_score)

    return {
        'bertscore_f1': bertscore_f1,
        'rouge_l_f1': rouge_l_f1,
        'meteor': meteor_score,
        'weighted_score': weighted_score
    }

# Define training arguments
training_arguments = TrainingArguments(
    output_dir="./results",
    evaluation_strategy="steps",
    eval_steps=25,
    # Move evaluation outputs to the CPU after every step. When this is unset,
    # the full-vocabulary logits of the whole validation set are accumulated on
    # the GPU before being moved, which can exhaust GPU memory.
    eval_accumulation_steps=1,
    save_strategy="epoch",
    logging_steps=5,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=8,  # effective batch of 8 examples per optimizer step and device
    optim="paged_adamw_8bit",
    num_train_epochs=4,
    weight_decay=0.01,
    logging_dir="./logs",
    learning_rate=1e-5,
    bf16=True,
    max_grad_norm=1.0,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    ddp_find_unused_parameters=False
)

# Assemble the Trainer inputs, then build the Trainer
trainer_inputs = {
    "model": peft_model,
    "args": training_arguments,
    "train_dataset": tokenized_datasets["train"],
    "eval_dataset": tokenized_datasets["validation"],
    "data_collator": data_collator,
    "compute_metrics": compute_metrics,
}
trainer = Trainer(**trainer_inputs)

torch.cuda.empty_cache()

# Baseline: score the model on the validation split before any training step
trainer.evaluate()

torch.cuda.empty_cache()

# Run the fine-tuning loop
trainer.train()

# Persist the fine-tuned model and the tokenizer side by side in one directory
for artifact in (peft_model, hf_tokenizer):
    artifact.save_pretrained("./fine_tuned_model_v2")

In [None]:
# Inference: reload the saved fine-tuned model and tokenizer to generate predictions

# Pick the GPU when one is available, otherwise the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

# Tokenizer saved next to the fine-tuned model
tokenizer = AutoTokenizer.from_pretrained("./fine_tuned_model_v2")

# Loading options for the fine-tuned model
load_options = {
    "torch_dtype": torch.bfloat16,  # bfloat16 weights to reduce the memory footprint
    "device_map": "auto",  # device placement is delegated to Accelerate
    "low_cpu_mem_usage": True,  # minimizes CPU memory usage during loading
}
fine_tuned_model = AutoModelForCausalLM.from_pretrained(
    "./fine_tuned_model_v2",
    **load_options
)

# Function to generate analysis for a single case using the fine-tuned model
def generate_single_analysis(case, model, tokenizer, device):
    """Generate an analysis for one case with the given model.

    Args:
        case: Case record accepted by ``format_prompt_nocontext``.
        model: Causal language model exposing ``generate``. The function now
            uses this argument instead of the global ``fine_tuned_model``.
        tokenizer: Tokenizer matching ``model``.
        device: Device used for the inputs when ``model`` has no ``device``
            attribute; otherwise the model's own device is used.

    Returns:
        The decoded output sequence with special tokens removed. It contains
        the prompt followed by the generated continuation.
    """
    prompt = format_prompt_nocontext(case)
    tokens = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=4096)

    # Send both input_ids and attention_mask to the device the model reports
    target_device = getattr(model, "device", device)
    tokens = {key: value.to(target_device) for key, value in tokens.items()}

    # Pass the padding id explicitly; fall back to EOS when the tokenizer defines no pad token
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    # Sample a continuation of up to 4096 new tokens
    generated_text = model.generate(
        input_ids=tokens['input_ids'],
        attention_mask=tokens['attention_mask'],
        max_new_tokens=4096,
        do_sample=True,
        temperature=0.7,
        use_cache=True,
        pad_token_id=pad_token_id
    )

    return tokenizer.decode(generated_text[0], skip_special_tokens=True)

results = []
references = []
generated_texts = []

# Create the output directory before the long generation loop, so that a
# missing './results' folder cannot discard all generated analyses at the end.
os.makedirs('./results', exist_ok=True)

# Perform inference on the test data
for case in tqdm(test_data, desc='Generating Analysis'):
    # Keep only the text after the last "[/INST]" marker, i.e. drop the prompt section
    analysis = generate_single_analysis(case, fine_tuned_model, tokenizer, device).split("[/INST]")[-1].strip()
    # The analysis is stored on the case record itself (the entries of test_data are modified in place)
    case['Fine_Tuned_Generated_Analysis'] = analysis
    results.append(case)

    # Storing reference and generated texts for evaluation
    references.append(case['Case_Result'])
    generated_texts.append(analysis)

# Save the generated analyses to a JSON file
with open('./results/fine_tuned_generated_analyses_v2.json', 'w', encoding='utf-8') as f:
    json.dump(results, f, indent=2)

print("Analysis Generated and Saved with Fine-Tuned Model")