Use Case: Build a Legal AI Assistant that classifies legal complaint texts by potential violation type (e.g., labor, human rights, environmental), and explains the classification using natural language reasoning.



In [None]:
# STEP 1: Install the Gemini SDK
!pip install -q google-generativeai

# STEP 2: Load  Gemini API Key
import os
import google.generativeai as genai

# Replace YOUR_API_KEY with your actual Gemini API key
GEMINI_API_KEY = "AIzaSyC5LBQFarxRwaadGBKywwig0UGJLpFOMFY"
os.environ["GOOGLE_API_KEY"] = GEMINI_API_KEY
genai.configure(api_key=GEMINI_API_KEY)

# STEP 3: Initialize the Gemini Pro model
model = genai.GenerativeModel("gemini-1.5-flash")

# STEP 4: Test the model
response = model.generate_content("Hello, what are you?")
print(response.text)


I am a large language model, trained by Google.  I'm an AI designed to process information and respond to a wide range of prompts and questions.  Essentially, I'm a computer program that can communicate and generate human-like text.



In [7]:
# Hand-labelled evaluation scenarios. Each row below is
# (id, case_text, context, expected) and is expanded into a dict with those
# keys, in that order.
test_cases = [
    dict(zip(("id", "case_text", "context", "expected"), row))
    for row in (
        (
            "001",
            "An employee was hired for a 2-year contract but was terminated without cause after 6 months.",
            "The contract stipulated early termination requires a 3-month severance.",
            "Breach of employment contract",
        ),
        (
            "002",
            "A blogger accused a public figure of financial fraud without presenting any evidence.",
            "The public figure suffered reputational damage and financial losses.",
            "Defamation (libel)",
        ),
        (
            "003",
            "A tech company copied key features from a patented software system without licensing.",
            "The original company filed a lawsuit under U.S. patent law.",
            "Patent infringement",
        ),
        (
            "004",
            "A customer slipped on a wet floor in a grocery store. There was no warning sign.",
            "The customer suffered a broken arm and sued for negligence.",
            "Premises liability due to negligence",
        ),
        (
            "005",
            "A seller misrepresented the condition of a used car as 'like new,' hiding previous accident damage.",
            "The buyer discovered the issue after purchase and sued for damages.",
            "Fraudulent misrepresentation",
        ),
    )
]


In [8]:
def generate_cot_prompt(case_text, context):
    """
    Assemble the chain-of-thought prompt for a single legal scenario.

    `case_text` and `context` are interpolated into the "Case:" and
    "Context:" lines of a fixed instruction template. The returned string
    starts and ends with a newline.
    """
    prompt_lines = [
        "",
        "You are a legal expert tasked with analyzing a legal document and identifying any potential violations.",
        "",
        f"Case: {case_text}",
        f"Context: {context}",
        "",
        "**Instructions**:",
        "- First, classify the violation (e.g., breach of contract, personal injury, etc.).",
        "- Then, explain step-by-step why you think this is the violation.",
        "- Provide relevant legal clauses or precedents to support your reasoning.",
        "",
        "Please answer the following questions:",
        "1. What violation has occurred in this case?",
        "2. Why do you classify this as [violation type]?",
        "3. What are the relevant legal clauses or precedents that apply to this violation?",
        "",
        "**Reasoning (Chain-of-Thought)**:",
        "1. Start by analyzing the case details step by step.",
        "2. Provide intermediate conclusions.",
        "3. Conclude with the final classification and reasoning.",
        "",
    ]
    # Joining with "\n" reproduces the leading and trailing newline of the template.
    return "\n".join(prompt_lines)


In [9]:
def classify_violation_with_cot(case_text, context):
    """
    Ask the notebook-level Gemini `model` to classify a scenario with
    Chain-of-Thought reasoning and return the raw response text.
    """
    # Build the CoT prompt, send it to the model and hand back its text as-is.
    return model.generate_content(generate_cot_prompt(case_text, context)).text


In [10]:
from IPython.display import Markdown, display
import google.generativeai as genai

def test_legal_scenarios(test_cases, model):
    markdown_table = "| Case ID | Expected | Model Output | Match | Notes |\n"
    markdown_table += "|---------|----------|--------------|-------|-------|\n"

    for case in test_cases:
        prompt = generate_cot_prompt(case["case_text"], case["context"])
        try:
            response = model.generate_content(prompt)
            output = response.text.strip()
            first_line = output.split('\n')[0]
            match = case["expected"].lower() in output.lower()
            markdown_table += f"| {case['id']} | {case['expected']} | {first_line} | {'✅' if match else '❌'} | {'-' if match else 'Check classification'} |\n"
        except Exception as e:
            markdown_table += f"| {case['id']} | {case['expected']} | ERROR | ❌ | {str(e)[:40]}... |\n"

    display(Markdown(markdown_table))


In [11]:
# Run the chain-of-thought check on every scenario and render the results
# table; this sends one request to the model per test case.
test_legal_scenarios(test_cases, model)


| Case ID | Expected | Model Output | Match | Notes |
|---------|----------|--------------|-------|-------|
| 001 | Breach of employment contract | **Reasoning (Chain-of-Thought):** | ✅ | - |
| 002 | Defamation (libel) | **Reasoning (Chain-of-Thought):** | ✅ | - |
| 003 | Patent infringement | ERROR | ❌ | Invalid operation: The `response.text` q... |
| 004 | Premises liability due to negligence | **1. What violation has occurred in this case?** | ❌ | Check classification |
| 005 | Fraudulent misrepresentation | **Chain-of-Thought:** | ✅ | - |


In [None]:
def build_constrained_prompt(scenario_text, labels):
    """
    Build a label-only classification prompt.

    The allowed `labels` are listed on one line, each wrapped in double
    quotes and separated by ", ", followed by the quoted `scenario_text` and
    an instruction to answer with the label alone.
    """
    quoted_labels = []
    for label in labels:
        quoted_labels.append(f'"{label}"')
    categories = ", ".join(quoted_labels)

    # Leading/trailing empty entries give the prompt its surrounding newlines.
    prompt_lines = [
        "",
        "Given the legal scenario below, classify it into one of the following categories:",
        categories,
        "",
        f'Scenario: "{scenario_text}"',
        "",
        "Only return the classification label. Do not explain.",
        "",
    ]
    return "\n".join(prompt_lines)


In [None]:
from collections import Counter

def classify_with_self_consistency(case_text, model, label_options, num_responses=5):
    """
    Classify a scenario by sampling the model several times and taking a
    majority vote over the valid answers.

    Args:
        case_text: scenario text to classify.
        model: object exposing `generate_content(prompt)` whose result has a
            `.text` attribute.
        label_options: allowed labels; matching is case-insensitive.
        num_responses: number of samples to request (default 5).

    Returns:
        The lower-cased label that was returned most often, or "unclear" when
        no sample produced a valid label. Ties go to the label seen first.
    """
    # Create the prompt based on the case text and possible labels
    prompt = build_constrained_prompt(case_text, label_options)

    # Lower-case the allowed labels once instead of once per response.
    valid_labels = {label.lower() for label in label_options}

    # Characters models commonly wrap around an otherwise exact label. The
    # prompt itself presents every label inside double quotes, so echoing the
    # quotes (or adding a final period) must not invalidate the answer.
    wrappers = " \t\r\n\"'`*."

    votes = []
    for _ in range(num_responses):
        response = model.generate_content(prompt)
        try:
            raw_answer = response.text
        except ValueError:
            # `.text` raises ValueError when the response carries no usable
            # text part (e.g. a blocked candidate). Treat that sample as an
            # invalid answer instead of aborting the whole vote.
            continue

        answer = raw_answer.strip().lower()
        if answer not in valid_labels:
            # Only strip wrappers when the exact text is not already a label,
            # so labels that legitimately end in such a character still match.
            answer = answer.strip(wrappers)
        if answer in valid_labels:
            votes.append(answer)

    # If no valid responses, return "unclear"
    if not votes:
        return "unclear"

    # Return the most frequent response using majority voting
    return Counter(votes).most_common(1)[0][0]


In [None]:
from sklearn.metrics import classification_report
from IPython.display import display, Markdown

def evaluate_model_with_self_consistency(test_cases, model, label_options, num_responses=5):
    """
    Classify every test case with self-consistency voting and report results.

    Args:
        test_cases: iterable of dicts with the keys "id" and "expected" plus
            the scenario text under either "input" or "case_text".
        model: forwarded to `classify_with_self_consistency`.
        label_options: allowed labels, forwarded to the classifier.
        num_responses: samples per case, forwarded to the classifier.

    Displays a Markdown table (case id, expected, predicted, match) and prints
    scikit-learn's classification report. Expected labels are stripped and
    lower-cased before comparison.

    Returns:
        None.

    Raises:
        KeyError: if a case lacks "id", "expected", or both text keys.
    """
    y_true = []  # expected labels, normalised
    y_pred = []  # labels predicted by the model

    table_rows = [
        "| Case ID | Expected | Predicted | Match |",
        "|---------|----------|-----------|-------|",
    ]

    for case in test_cases:
        expected = case['expected'].strip().lower()

        # The notebook's test cases store the scenario under "case_text";
        # "input" is still honoured first for callers that use that key.
        scenario_text = case['input'] if 'input' in case else case['case_text']

        # Get the prediction from the model using self-consistency
        prediction = classify_with_self_consistency(scenario_text, model, label_options, num_responses)

        y_true.append(expected)
        y_pred.append(prediction)

        match = "✅" if prediction == expected else "❌"
        table_rows.append(f"| {case['id']} | {expected} | {prediction} | {match} |")

    # Display the markdown table with results
    display(Markdown("\n".join(table_rows)))

    # Calculate and print the classification report (precision, recall, F1-score)
    print("\nClassification Report:\n")
    print(classification_report(y_true, y_pred, zero_division=0))


In [None]:
# Closed set of labels the constrained classifier may answer with
# (lower-case spellings of the `expected` values used in the test cases).
label_options = list((
    'breach of employment contract',
    'defamation (libel)',
    'patent infringement',
    'premises liability due to negligence',
    'fraudulent misrepresentation',
))


In [None]:
# Evaluate every test case with majority voting over 5 sampled answers per
# scenario, then display the results table and print the classification report.
evaluate_model_with_self_consistency(test_cases, model, label_options, num_responses=5)
