# Section 2: Prompting with DSPy
This section defines the functions and setup needed to build prompts with DSPy and to optimize them.
The prompts structure the input case data before it is passed to the model.

In [None]:
#!pip install dspy-ai

import dspy
import torch
import json
import os
from dspy.evaluate import Evaluate
from dspy.teleprompt import MIPRO
from bert_score import score as bert_score
from rouge_score import rouge_scorer
from transformers import BitsAndBytesConfig
import accelerate
from datetime import datetime

def load_json(filename):
    """Load and return the parsed contents of a JSON file.

    Args:
        filename: Path to the JSON file.

    Returns:
        The deserialised JSON value. For the case file loaded in this notebook
        it is presumably a list of case dicts (indexed by slice and by string
        keys further down) -- confirm against the data.
    """
    # Decode explicitly as UTF-8: JSON text is UTF-8 by specification, whereas
    # open() without an encoding uses the platform locale (e.g. cp1252 on
    # Windows) and can garble or reject non-ASCII characters in the case text.
    with open(filename, 'r', encoding='utf-8') as f:
        return json.load(f)

# PyTorch reads PYTORCH_CUDA_ALLOC_CONF only when its CUDA caching allocator is
# first initialised, so it must be set before the model below places any
# weights on the GPU; assigning it after the model load has no effect.
# NOTE(review): this still assumes no earlier notebook cell has already used CUDA.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:32'

# Load the JSON data containing the case details
cases_data = load_json('./data/cleaned_nia_cases.json')

# 4-bit quantization settings so the model loads with a smaller memory footprint
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,  # store weights in 4-bit precision
    bnb_4bit_quant_type="nf4",  # NormalFloat4 data type for the 4-bit weights
    bnb_4bit_compute_dtype=torch.bfloat16,  # run computations in bfloat16
    bnb_4bit_use_double_quant=True,  # nested quantization: also quantizes the quantization constants to save more memory
)

# Hugging Face access token, read from the HF_TOKEN environment variable rather
# than hard-coded in the notebook (a placeholder literal would be sent as an
# invalid credential). When unset this is None, and the Hugging Face client
# falls back to any token cached by `huggingface-cli login`.
hf_token = os.environ.get("HF_TOKEN")

# Initialize the HF model using DSPy HFModel class with quantization settings
dspy_model = dspy.HFModel(
    model='mistralai/Mixtral-8x7B-Instruct-v0.1',  # Model ID
    model_kwargs={
        "quantization_config": quantization_config,  # apply the 4-bit settings above
        "device_map": "auto",  # automatically distribute the model across available devices
        "token": hf_token,  # access token for the Hugging Face Hub (may be None)
    },
)

# Configure DSPy with the initialized HF model
dspy.settings.configure(lm=dspy_model)

# Device passed to BERTScore inside the evaluation metric
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# memory_summary() returns a report string (it was previously discarded) and
# needs CUDA, so it fails on a CPU-only machine: only call it when CUDA is
# present, and print the report so it is actually shown.
if torch.cuda.is_available():
    print(torch.cuda.memory_summary())

# ROUGE-L scorer built once and reused across metric calls: its construction is
# call-invariant, and only the ROUGE-L F-measure is used, so rouge1/rouge2 are
# no longer computed.
_ROUGE_L_SCORER = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)


def custom_metric_evaluation(original, generated, trace=None):
    """Score a generated verdict against the reference using BERTScore and ROUGE-L.

    Used as the DSPy metric for the MIPRO optimizer in this notebook.

    Args:
        original: Reference example; its ``petition_verdict`` is the gold text.
        generated: Model prediction; its ``petition_verdict`` is the candidate text.
        trace: Supplied by DSPy during optimisation; not used here.

    Returns:
        float: ``0.15 * BERTScore-F1 + 0.85 * ROUGE-L-F1``.
    """
    reference = original.petition_verdict
    # A prediction whose output field failed to parse may lack the field or
    # hold None; score it as empty text instead of crashing the scorers.
    candidate = getattr(generated, 'petition_verdict', None) or ''

    # bert_score.score takes (candidates, references) in that order. Passing
    # them swapped leaves F1 unchanged but silently swaps precision and recall.
    _, _, bert_f1 = bert_score([candidate], [reference], lang="en", verbose=True, device=device)

    # rouge_score's score() takes (target, prediction).
    rouge_l_f1 = _ROUGE_L_SCORER.score(reference, candidate)['rougeL'].fmeasure

    # Weighted combination, heavily favouring lexical overlap (ROUGE-L).
    # NOTE(review): during bootstrapping DSPy presumably treats the metric's
    # truthiness as pass/fail when `trace` is not None, so any nonzero score
    # counts as a pass here -- confirm this is intended.
    return 0.15 * bert_f1.mean().item() + 0.85 * rouge_l_f1

def prepare_text_input(input_value):
    """Flatten a list of text fragments into one space-separated string.

    Any value that is not a list (including strings and None) is returned
    unchanged.
    """
    if not isinstance(input_value, list):
        return input_value
    # e.g. ["a", "b"] -> "a b"
    return " ".join(input_value)

# Instructions handed to the model as the `guidelines` input of every example.
# The text is part of the prompt, so any edit changes what is optimised.
# NOTE(review): the quoted labels carry trailing punctuation ('ALLOWED,'),
# which may nudge the model to emit a label with a comma/period -- confirm.
guidelines = """
You are a seasoned Indian High Court judge specializing in banking law, particularly the Indian Negotiable Instruments Act.
Follow these guidelines:
1. Analyze the legal arguments concisely.
2. Refer to past cases and applicable laws.
3. Provide a final verdict, and conclude with a single word: 'ALLOWED,' 'DISMISSED,' or 'OTHER.'
"""


def _case_to_example(case):
    """Turn one raw case record into a DSPy example with its prompt inputs marked."""
    # NOTE(review): source keys mix styles ('Arg Resp' with a space vs
    # 'Arg_Pet' with an underscore); kept verbatim -- confirm against the JSON.
    fields = {
        'legal_case_facts': prepare_text_input(case['Facts']),
        'arguments_by_respondent': prepare_text_input(case['Arg Resp']),
        'argument_by_petitioner': prepare_text_input(case['Arg_Pet']),
        'ruling_by_lower_court': prepare_text_input(case['RLC']),
        'petition_verdict': prepare_text_input(case['Analysis']),  # reference output
        'guidelines': guidelines,
    }
    # Everything except the reference verdict is a model input.
    input_names = (
        "legal_case_facts",
        "arguments_by_respondent",
        "argument_by_petitioner",
        "ruling_by_lower_court",
        "guidelines",
    )
    return dspy.Example(**fields).with_inputs(*input_names)


# Only the first 5 cases are used because of memory constraints.
dataset = list(map(_case_to_example, cases_data[:5]))

# Chain-of-thought program: the five case inputs map to one output field,
# `petition_verdict`, which the metric compares with the reference analysis.
LegalAnalysis = dspy.ChainOfThought(
    "guidelines,legal_case_facts,arguments_by_respondent,argument_by_petitioner,ruling_by_lower_court -> petition_verdict"
)

# MIPRO prompt optimiser driven by the custom BERTScore/ROUGE-L metric;
# num_candidates sets how many prompt candidates it generates.
config = {"num_candidates": 4}
optimizer = MIPRO(metric=custom_metric_evaluation, **config)

# Evaluation settings used while scoring candidates: one thread, progress bar
# shown, no results table.
kwargs = {"num_threads": 1, "display_progress": True, "display_table": 0}

# Run the optimisation unattended (no permission prompt, no data/example previews).
# NOTE(review): candidates are scored on the same 5-example trainset they are
# built from, so reported scores are in-sample.
compiled_program = optimizer.compile(
    LegalAnalysis,
    trainset=dataset,
    eval_kwargs=kwargs,
    num_trials=3,
    max_labeled_demos=2,
    max_bootstrapped_demos=2,
    view_data=False,
    view_examples=False,
    requires_permission_to_run=False,
)

# Save the optimized prompt configuration for future use. The parent directory
# is created first: writing to a path inside a missing directory raises
# FileNotFoundError, which would lose the result of the whole optimisation run.
optimized_prompt_path = "./results/optimized_promptv2"
os.makedirs(os.path.dirname(optimized_prompt_path), exist_ok=True)
compiled_program.save(optimized_prompt_path)
print(compiled_program)

print("Analysis Generated and Saved with Optimized Prompts")