In [1]:
import json
import re

import torch
import transformers

# Build the shared text-generation pipeline once; the functions defined in
# later cells call the module-level name `pipeline` directly.
model_id = "meta-llama/Llama-3.1-8B-Instruct"
pipeline = transformers.pipeline(
    "text-generation",
    model=model_id,
    # Extra keyword arguments handed through for model construction.
    model_kwargs=dict(torch_dtype=torch.bfloat16),
    # "auto" leaves device placement to the library.
    device_map="auto",
)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [2]:

# Step 1: Generate Reasoning Steps from Context, Question, and Options
def generate_reasoning_steps(context, question, options):
    """Prompt the LLM for premises and a reasoning chain, then parse its answer.

    The prompt asks the model to emit a "### Question-Context Steps:" section
    followed by a "### Reasoning Steps:" section. Only the model's completion
    is parsed: if the pipeline echoes the prompt at the start of the generated
    text, that echo is removed first so the template lines inside the prompt
    are never mistaken for model output.

    Args:
        context: Legal context text inserted into the prompt.
        question: Question text inserted into the prompt.
        options: Answer options; formatted into the prompt with an f-string.

    Returns:
        A 4-tuple ``(question_context_steps, reasoning_step_list,
        reasoning_steps, preliminary_information)``:
        * question_context_steps: stripped, non-empty lines of the
          question-context section (``[]`` if the marker is missing).
        * reasoning_step_list: stripped, non-empty lines of the reasoning
          section (``[]`` if the marker is missing).
        * reasoning_steps: the unmodified generated text returned by the
          pipeline.
        * preliminary_information: the "Context / Question / Options" text
          that was appended to the task description.
    """
    # Define the prompt for the first LLM instance to generate reasoning steps and premises
    task_description = """
Task:
You are a helpful legal assistant. Choose the correct option by performing legal reasoning while strictly adhering to the legal context below. 
The reasoning should start by generating the question-related premises first, which are the statements and relationships given in the Question and Legal Context. 
Then based on these, generate the reasoning process in steps. Each reasoning step should:
1. Be logically consistent and based on the available premises.
2. Cite the premises used to support each reasoning step in the format: {QC #}, where # refers to the premise number.
3. Clearly explain how each step leads to the final answer.
4. The reasoning steps should be in a sequential, deductive chain, with each step based only on the premises and steps generated previously.

Please ensure that the output is clearly separated by special markers so that it can be parsed into individual elements later on.

---


Start with the Question-Context steps, followed by the Reasoning steps. Use the following structure:

### Question-Context Steps:
{QC 1} statement or relationship given/inferred in the Question and Legal Context.
{QC 2} statement or relationship given/inferred in the Question and Legal Context.
{QC 3} statement or relationship given/inferred in the Question and Legal Context.

---

### Reasoning Steps:
Step 1: [Premise-based reasoning derived from {QC 1}, {QC 2}, etc.]
Step 2: [Next reasoning step derived from previous steps and citing relevant premises]
...
Step N: [Final reasoning step and conclusion based on the chain of reasoning]
"""

    # The same block is appended to the prompt and returned to the caller,
    # so build it once.
    preliminary_information = f"\nContext: {context}\nQuestion: {question}\nOptions: {options}"
    prompt = task_description + preliminary_information

    # Generate the reasoning steps
    # NOTE(review): per the transformers docs max_length appears to bound
    # prompt + completion tokens; consider max_new_tokens for long contexts.
    result = pipeline(prompt, max_length=4000, num_return_sequences=1)
    reasoning_steps = result[0]['generated_text']

    # Parse only what the model wrote. Counting marker occurrences across the
    # echoed prompt is fragile and raises IndexError as soon as the model
    # omits a marker.
    if reasoning_steps.startswith(prompt):
        completion = reasoning_steps[len(prompt):]
    else:
        completion = reasoning_steps

    question_context_marker = "### Question-Context Steps:"
    reasoning_steps_marker = "### Reasoning Steps:"

    question_context_steps = _parse_section_lines(
        _text_after_marker(
            completion,
            question_context_marker,
            stop_markers=(question_context_marker, reasoning_steps_marker),
        )
    )
    reasoning_step_list = _parse_section_lines(
        _text_after_marker(
            completion,
            reasoning_steps_marker,
            stop_markers=(reasoning_steps_marker,),
        )
    )

    return question_context_steps, reasoning_step_list, reasoning_steps, preliminary_information


def _text_after_marker(text, marker, stop_markers):
    """Return the text after the first `marker`, cut at the next stop marker ('' if `marker` is absent)."""
    _, found, remainder = text.partition(marker)
    if not found:
        return ""
    for stop in stop_markers:
        remainder = remainder.partition(stop)[0]
    return remainder


def _parse_section_lines(section_text):
    """Split a section into stripped lines, dropping blank lines and dash-only separator rows such as '---'."""
    parsed = []
    for raw_line in section_text.split("\n"):
        line = raw_line.strip()
        # strip("-") is empty for both "" and "---", so one check covers both.
        if line.strip("-"):
            parsed.append(line)
    return parsed
    

In [3]:
# Step 2: Validate Each Reasoning Step
def validate_reasoning_step(question_context_steps, reasoning_steps, preliminary_information, step_number):
    """Ask the LLM whether one reasoning step follows from what precedes it.

    Args:
        question_context_steps: List of premise strings; joined with newlines
            under "### Question-Context Premises:".
        reasoning_steps: List of reasoning-step strings.
        preliminary_information: Text embedded verbatim near the top of the
            validation prompt (context, question and options).
        step_number: Index into ``reasoning_steps`` of the step to validate;
            the steps before it are listed as previous reasoning premises.

    Returns:
        True if the first whole-word "Yes"/"No" verdict in the model's reply
        is "Yes"; False if it is "No" or if no verdict is found.

    Raises:
        IndexError: If ``step_number`` is out of range for ``reasoning_steps``.
    """
    # Construct the validation prompt using f-string and proper concatenation
    validation_prompt = f"""
Task:
You are a legal expert. Validate the following reasoning step. Based on the given context, question, and options and other reasoning steps until then, determine if the reasoning step and its premises are logically sound and derivable from the provided information.

{preliminary_information}

### Question-Context Premises:
{chr(10).join(question_context_steps)}

### Previous Reasoning Premises: 
{chr(10).join(reasoning_steps[:step_number])}

### Reasoning Step:
{reasoning_steps[step_number]}

Does this reasoning step follow logically from the context, question, options and other given premises information? Answer 'Yes' or 'No', and explain why.
"""

    # Validate the reasoning step
    result = pipeline(validation_prompt, max_length=4000, num_return_sequences=1)
    validation_output = result[0]['generated_text']

    # The prompt itself contains "Answer 'Yes' or 'No'", so searching the full
    # generated text for "Yes" succeeds whenever the prompt is echoed back.
    # Judge only the part the model wrote.
    if validation_output.startswith(validation_prompt):
        model_reply = validation_output[len(validation_prompt):]
    else:
        model_reply = validation_output

    # Heuristic: the first capitalised whole-word verdict decides. Lower-case
    # "no" is skipped on purpose because it is a common word in explanations.
    verdict = re.search(r"\b(Yes|YES|No|NO)\b", model_reply)
    return verdict is not None and verdict.group(1).lower() == "yes"


In [4]:

# Step 3: Process Dataset and Generate Prompts for Each Sample
def process_sample(sample):
    """Generate a reasoning chain for one sample and validate every step of it.

    Args:
        sample: Mapping with the keys 'context', 'question' and 'options'.

    Returns:
        A 4-tuple ``(reasoning_validations, reasoning_steps, generated_text,
        question_context_steps)``, where ``reasoning_validations`` holds one
        validation result per entry of ``reasoning_steps``, in order.
    """
    premises, steps, raw_generation, background = generate_reasoning_steps(
        sample['context'],
        sample['question'],
        sample['options'],
    )

    # Steps are checked in order, one validation call per step index.
    verdicts = [
        validate_reasoning_step(premises, steps, background, index)
        for index in range(len(steps))
    ]

    return verdicts, steps, raw_generation, premises


In [None]:

# Main function to load data, process and validate reasoning steps for all samples
def main(input_path='legal_reasoning_30_sample.json', output_path='processed_legal_reasoning_samples.json'):
    """Run generation and validation over every sample and save the results.

    Args:
        input_path: JSON file containing a top-level 'legal_scenarios' list
            of samples. Defaults to the original hard-coded file name.
        output_path: Destination JSON file for the augmented data. Defaults
            to the original hard-coded file name.

    Each sample is updated in place with the keys 'reasoning_validations',
    'reasoning_steps', 'generated_text' and 'question_context_steps' before
    the whole structure is written to ``output_path``.

    Raises:
        KeyError: If the loaded JSON has no 'legal_scenarios' key.
    """
    # Explicit UTF-8 so the files do not depend on the platform's default
    # encoding.
    with open(input_path, 'r', encoding='utf-8') as file:
        data = json.load(file)

    # Process each sample and validate reasoning steps
    for sample in data['legal_scenarios']:
        reasoning_validations, reasoning_steps, generated_text, question_context_steps = process_sample(sample)

        # Add the generated data to the sample
        sample['reasoning_validations'] = reasoning_validations
        sample['reasoning_steps'] = reasoning_steps
        sample['generated_text'] = generated_text
        sample['question_context_steps'] = question_context_steps

    # Save the updated data into a new JSON file
    with open(output_path, 'w', encoding='utf-8') as output_file:
        json.dump(data, output_file, indent=4)

# Run the full pipeline when this code is executed directly rather than imported.
if __name__ == "__main__":
    main()
