This notebook provides the code for generating the Python code representation of legal text by prompting GPT-4o with the class structure and demonstrations. The demonstrations are selected based on a strategy involving attribute overlap and cosine similarity as described in the paper. 

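Please note that executing this notebook requires an OpenAI API key, which the OpenAI client reads from the `OPENAI_API_KEY` environment variable unless a key is passed to it explicitly.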

In [None]:
import csv
import json
import random
import re
from collections import Counter
from pathlib import Path

import numpy as np
import pandas as pd
from langchain_openai import OpenAIEmbeddings
from openai import OpenAI

# Seed Python's built-in RNG so the random.shuffle / random.sample calls used for
# demonstration selection later in the notebook are repeatable from a fresh kernel
random.seed(42)

# Load Train and Test File

In [None]:
# Load the development data from CSV.
# Later cells read the 'tags', 'text' and 'code' columns of this frame.
df_train = pd.read_csv('path/to/train.csv') # Specify the correct path to your development data
df_train.head()

In [None]:
# Print the development set tag distribution, most frequent tag first.
# The 'tags' cells are expected to still be raw strings such as "['#a', '#b']" here,
# i.e. this cell has to run before the column is converted to lists.

counter = Counter()
for raw_tags in df_train['tags']:
    # Drop the surrounding brackets, split on commas, then strip whitespace and the
    # one-character quotes around each tag. Dedicated loop names are used so the
    # builtin `iter` is no longer rebound at notebook scope.
    counter.update(piece.strip()[1:-1] for piece in raw_tags[1:-1].split(','))
for tag, count in counter.most_common():
    print(f"{tag}: {count}")

In [None]:
# Convert the 'tags' column from its string form "['#a', '#b']" to a list of tag strings.
# Cells that are no longer strings (already converted) are passed through unchanged, so
# re-running this cell does not fail or mangle the column.
df_train['tags'] = df_train['tags'].apply(
    lambda x: [d.strip()[1:-1] for d in x[1:-1].split(',')] if isinstance(x, str) else x
)
df_train.head()

In [None]:
# Randomly shuffle the development data set with a fixed seed (random_state=42), then
# renumber the rows from 0 so the strided fold slicing in the next cell follows the shuffled order
df_train = df_train.sample(frac=1, random_state=42).reset_index(drop=True)
df_train.head()


In [142]:
# Split the development data into k interleaved folds:
# fold number s holds rows s, s + k, s + 2k, ... of the shuffled frame
k = 5
folds = [df_train[start::k] for start in range(k)]

In [117]:
# Store the folds in the development set folder.
# DataFrame.to_csv does not create missing parent directories, so make sure the
# folder exists first; exist_ok=True keeps the cell re-runnable.
fold_dir = Path('development_set')
fold_dir.mkdir(parents=True, exist_ok=True)
for i, fold in enumerate(folds):
    fold.to_csv(fold_dir / f'fold-{i}.csv', index=False)

In [None]:
# For k fold setting
#
# The fold at position `index` becomes the testing set dataframe;
# all remaining folds are concatenated into the development set dataframe.

index = 0 # Change this index to select different folds for testing
df_test = folds[index]
remaining_folds = [*folds[:index], *folds[index + 1:]]
df_train = pd.concat(remaining_folds)

print(f"development set size: {len(df_train)}")
print(f"Testing set size: {len(df_test)}")

In [None]:
# For test file

# Read the testing set from CSV. Later cells read its 'text' column and add
# 'tags' and 'embedding' columns before it is turned into a list of dicts.
testing_set = pd.read_csv('path/to/test.csv')  # Specify the correct path to your testing data
testing_set.head()

# Tagging Test File

In [None]:
# Global variable holding the predefined tag definitions that are pasted verbatim into
# the tagging prompt. The text is laid out as a dict literal: tag -> definition, with
# typical trigger words in square brackets at the end of the definition.
# Every entry keeps its keyword hint inside the quoted definition and is comma-separated,
# so the block is a well-formed literal for the model (and for ast.literal_eval).

predefined_tags = """
{
    '#definition': 'a legal statement defining the meaning of concepts [mean, include]',
    '#exclusion': 'a phrase highlighting what is excluded from the definition of a term [exclude, not include]',
    '#exemption': 'a legal statement that exempts someone/something from a rule [exempt, does not apply to, does not require]',
    '#obligation': 'a statement imposing mandatory action to be performed by an agent [shall, must]',
    '#permission': 'a statement indicating the possibility to perform an action without an obligation or a prohibition [may, is permitted to, can, be deemed]',
    '#prohibition': 'a statement forbidding an action to happen or take place [may not, shall not, must not]',
    '#penalty': 'a statement indicating the punishment for not following a rule',
    '#information': 'a legal statement about something that is known or proved to be true',
    '#continuation': 'denoting nested legal statements; assigned whenever a phrase contains a colon and is followed by a bullet list',
    '#condition': 'a phrase in a statement highlighting a constraint under which a rule applies [if, when, after]',
    '#follows': 'relation that connects a statement to references or other statements that precede (act as pre-conditions to) the statement [pursuant to, in accordance with, under]',
    '#refines': 'relation that connects a statement that provides more information about a reference or base statement to the reference or base statement',
    '#followed_by': 'relation that connects a statement to references or other statements that follow the statement',
    '#refined_by': 'relation that connects a base statement to a cross reference or another statement that provides more information about the base statement [as defined in, as described in]',
    '#exception': 'relation that connects a statement to references or other expressions that are exceptions to the statement [unless, except]',
    '#exception_to': 'relation that connects a statement that acts as a exception to a reference or base statement with the reference or base statement [notwithstanding]',
    '#reference': 'when the text contains pointers, numbers, or names to other sections, paragraphs, or laws'
}
"""

In [None]:
# Initialize the OpenAI client used by prompt_model for all chat completions
# Make sure to set your OpenAI API key in the environment variable OPENAI_API_KEY
# Alternatively, you can pass the key directly to the OpenAI constructor with `api_key='your_api_key'`
# (do not commit a hard-coded key together with the notebook)

client = OpenAI()

def prompt_model(prompt, model='gpt-4o', temperature=0.5):
    """
    Send a single-turn user prompt to an OpenAI chat model and return the reply text.
    Args:
        prompt (str): The input prompt to send to the model.
        model (str): Name of the chat model to query. Defaults to 'gpt-4o'.
        temperature (float): Sampling temperature passed to the API. Defaults to 0.5.
    Returns:
        str: The message content of the first choice in the completion.
    """
    completion = client.chat.completions.create(
        model=model,
        # NOTE(review): store=True presumably asks the API to retain the completion;
        # confirm this is acceptable for the texts being sent.
        store=True,
        # The whole prompt is sent as one user message; no system message is used
        messages=[
            {'role': 'user', 'content': prompt}
        ],
        temperature=temperature
    )
    return completion.choices[0].message.content

In [None]:
# Prompt template for the tagging step. The first %s slot receives the tag
# definitions, the second the text to be tagged.
# NOTE: the names `prompt` and `exec_prompt` are both rebound further down in this
# notebook for the code-generation step, so re-run this cell before re-running the
# tagging loop.

prompt = """Read the text and assign tags based on the definitions provided. Do not create your own tags. Only output the tags in the form of a python list. Do not include the assigned parts of the text in your response.

Tag Definitions:
%s

Text: %s
Tags: """


def exec_prompt(text):
    """
    Ask the model to tag a single text.

    The tagging template is filled with the predefined tag definitions and the
    given text, and the result is sent to the model as one prompt.
    Args:
        text (str): The input text to analyze.
    Returns:
        str: The raw model response, expected to contain the assigned tags.
    """
    return prompt_model(prompt % (predefined_tags, text))

In [None]:
# Tag every sample of the testing set, in row order, and keep the raw responses.
# One dot is printed per finished request as a minimal progress indicator.

print('Prompting for %i test samples' % len(testing_set), end='')
answers = []
for sample_text in testing_set['text']:
    answers.append(exec_prompt(sample_text))
    print('. ', end='')
print()

In [None]:
def extract_tags(answer):
    """
    Function to extract tags from the model's response.

    The response is expected to be a Python-style list such as "['#a', '#b']",
    optionally wrapped in a Markdown code fence. The list is located by its
    brackets rather than by fixed character offsets, so surrounding whitespace,
    fence variants and extra text around the list do not corrupt the tags.
    Args:
        answer (str): The model's response containing the tags.
    Returns:
        list: The unique, non-empty tags in order of first appearance
            (an empty list when the response contains no tags).
    """
    text = (answer or '').strip()

    open_idx, close_idx = text.find('['), text.rfind(']')
    if open_idx != -1 and close_idx > open_idx:
        # Keep only what is between the outermost brackets
        inner = text[open_idx + 1:close_idx]
    else:
        # No list brackets at all: drop fence markers and parse whatever remains
        inner = re.sub(r'```[A-Za-z]*', '', text)

    # Each item may be wrapped in single or double quotes, or not quoted at all
    tags = (piece.strip().strip('\'"').strip() for piece in inner.split(','))
    # dict.fromkeys de-duplicates while keeping a deterministic order
    return list(dict.fromkeys(tag for tag in tags if tag))

# Extract tags from the model's answers and store them in the testing set
# (answers holds one response per row, in row order, so the list lines up positionally)
testing_set['tags'] = [extract_tags(a) for a in answers]

# Demonstration Selection

In [None]:
# Create the embedding client (model "text-embedding-3-large"); no text is embedded here.
# The development and testing texts are embedded in later cells via embeddings.embed_query.
embeddings = OpenAIEmbeddings(model="text-embedding-3-large")

def cosine_similarity(v1, v2):
    """
    Function to compute the cosine similarity between two vectors.
    Args:
        v1 (list or np.array): First vector.
        v2 (list or np.array): Second vector.
    Returns:
        float: Cosine similarity between the two vectors, or 0.0 when either
            vector has zero length (the similarity is undefined there).
    """
    a = np.asarray(v1, dtype=float)
    b = np.asarray(v2, dtype=float)
    denominator = np.linalg.norm(a) * np.linalg.norm(b)
    # A zero vector would otherwise yield NaN, which breaks ordering when the
    # result is used as a sort key
    if denominator == 0:
        return 0.0
    return float(np.dot(a, b) / denominator)

def select_demos(test_sample, demos, n=5):
    """
    Pick the demonstrations that best match a test sample.

    Only demonstrations sharing at least one tag with the test sample are
    considered. They are ranked by the number of shared tags first and by the
    cosine similarity of the embeddings second, both in descending order.
    Args:
        test_sample (dict): The test sample containing 'tags' and 'embedding'.
        demos (list): List of demonstration samples, each containing 'tags' and 'embedding'.
        n (int): Number of top demonstrations to return.
    Returns:
        list: The best-ranked demonstration samples, at most n of them
            (empty when no demonstration shares a tag with the test sample).
    """
    wanted = set(test_sample['tags'])

    # Pair each candidate with the tags it shares with the test sample
    candidates = []
    for demo in demos:
        shared = wanted & set(demo['tags'])
        if not shared:
            continue
        candidates.append([shared, demo])

    # Shuffle first: the sort below is stable, so this randomises the order of exact ties
    random.shuffle(candidates)

    def rank(pair):
        shared, demo = pair
        similarity = cosine_similarity(test_sample['embedding'], demo['embedding'])
        return (len(shared), similarity)

    candidates.sort(key=rank, reverse=True)

    # Keep the demos only, cut down to the top n (fewer if there are not enough)
    return [demo for _, demo in candidates][:n]

    

In [None]:
# Embed every development-set text, one embed_query call per row, in row order
train_embeddings = [embeddings.embed_query(text) for text in df_train['text']]

# Attach the vectors to the development set DataFrame as the 'embedding' column
df_train['embedding'] = train_embeddings

In [None]:
# Embed every testing-set text, one embed_query call per row, in row order
test_embeddings = [embeddings.embed_query(text) for text in testing_set['text']]

# Attach the vectors to the testing set DataFrame as the 'embedding' column
testing_set['embedding'] = test_embeddings

In [None]:
# Convert the development set to a list of dictionaries (one dict per row)
development_set = df_train.to_dict(orient='records')

# Convert the testing set to a list of dictionaries as well. The name is rebound from a
# DataFrame to a list, so the guard keeps this cell re-runnable once that has happened.
if isinstance(testing_set, pd.DataFrame):
    testing_set = testing_set.to_dict(orient='records')

# Preview: select demonstrations for the first test sample, number of demos to select is set to 3
demos = select_demos(testing_set[0], development_set, n=3)

# If no demonstrations are selected, assign random demonstrations from the development set
# (never asking for more samples than the development set holds, which would raise ValueError)
if len(demos) == 0:
    demos = random.sample(development_set, min(3, len(development_set)))

print('Demonstration Count: %i' % len(demos))
print('Demonstration Tag Coverage: %s' % {t for dem in demos for t in dem['tags']})

# Prompting LLM to Generate the Code Representation

In [None]:
# Source text of the class structure (Section, Expression, Reference, Statement, Information,
# Definition, Rule, Exemption). In the cells shown here it is never executed: it is pasted
# verbatim, wrapped in a ```python fence, into the code-generation prompt.
# NOTE(review): the text is not runnable as-is -- it contains no `typing` imports and
# Reference.__init__ annotates `target: Statement` before Statement is defined. This is
# harmless while it is only used as prompt text.
# NOTE(review): a few details disagree with each other: Statement.__init__ stores
# `self.sections` while its docstring documents `section`; the "follows" relationship says
# "post-conditions" whereas the '#follows' tag definition says "pre-conditions"; Reference's
# docstring allows an Expression target but the signature is annotated with Statement only.
# Left unchanged because editing the text changes the prompt -- confirm which is intended.
code_string = """ 
class Section:
    \"""
    A bullet point in the legal text. Every bullet point starts a new Section,
    and sub-bullet points become subSections.

    Attributes:
        sectionNumber (str): The identifying number or label of this Section.
        sectionTitle (str): An optional title for this Section.
        parent (Optional[Section]): The parent Section if this is a nested (sub-)Section,
            otherwise None for a top-level Section.
        subSections (List[Section]): Any child Sections nested under this Section.
        expressions (List[Expression]): The Expression objects contained directly in this Section.
        statements (List[Statement]): The Statement objects contained directly in this Section.

    Methods:
        add_subsection(subsection: 'Section'):
            Adds a subsection (child) to this Section and sets the subsection's parent to self.

        add_expression(expression: 'Expression'):
            Adds an Expression object to this Section’s expressions list.

        add_statement(statement: 'Statement'):
            Adds a Statement object to this Section’s statements list.
    \"""

    def __init__(self,sectionNumber: str, sectionTitle: str = "", parent=None):
        self.sectionNumber: str = sectionNumber
        self.sectionTitle: str = sectionTitle
        self.parent: Optional['Section'] = parent
        self.subSections: List['Section'] = []
        self.expressions: List['Expression'] = []
        self.statements: List['Statement'] = []

    def add_subsection(self, subsection: 'Section'):
        self.subSections.append(subsection)
        subsection.parent = self
    def add_expression(self, expression: 'Expression'):
        self.expressions.append(expression)
    def add_statement(self, statement: 'Statement'):
        self.statements.append(statement)


class Expression:
    \"""
    A snippet of text within one bullet point (Section). Represents the smallest
    textual unit that can contain references in the text or other embedded elements.

    Each Expression belongs to exactly one Section.

    Attributes:
        section (Section): The Section in which this Expression is found.
        text (str): The textual content of the Expression.
        includes (Optional[List[Expression]]): A child Expression in a subsection that this Expression includes
    \"""

    def __init__(self, section: Section, text: str, includes=None):
        self.section: Section = section
        section.add_expression(self)
        self.text: str = text
        self.includes: Optional[List[Expression]] = includes if includes is not None else []

        
class Reference(Expression):
    \"""
    A type of Expression that refers to another part of the legal text.

    Attributes:
        target (Union[Expression, Statement]): The target Expression or Statement that this Reference points to.
    \"""

    def __init__(self, section: Section, text: str, target: Statement):
        super().__init__(section, text)
        self.target: Statement = target


class Statement:
    \"""
    A legal statement that can span multiple bullet points (Sections) if those
    bullet points are nested under a single conceptual clause. Statements often
    contain or refer to multiple Expressions.

    Attributes:
        section (Section): The Section that represents
            the location in the text where this Statement starts.
        relationships (dict of str -> List[Expression or Statement]): A dictionary of
            six possible relationship types, each mapping to a list of Expressions or Statements
            that are connected to this Statement or references that are present within the statement 
            in the specified manner.

    Relationship keys:
        - "refines": A list of References or Statements this Statement refines
            (providing more detail about).
        - "is_refined_by": A list of References or Statements that refine this Statement.
        - "has_exception": A list of References or Statements that are exceptions to this Statement.
        - "is_exception_to": A list of References or Statements for which this Statement is an exception.
        - "follows": A list of References or Statements that precede (act as post-conditions to) this Statement.
        - "is_followed_by": A list of References or Statements that follow this Statement.

    Methods:
        add_refines(target): Adds a target to the "refines" relationship.
        add_exception(exception): Adds a target to the "has_exception" relationship.
        add_follows(target): Adds a target to the "follows" relationship.
        add_is_refined_by(target): Adds a target to the "is_refined_by" relationship.
        add_is_exception_to(exception): Adds a target to the "is_exception_to" relationship.
        add_is_followed_by(target): Adds a target to the "is_followed_by" relationship.
    \"""

    def __init__(self, section: Optional[Section] = None):
        self.sections: Section = section
        self.relationships = {
            "refines": [],
            "is_refined_by": [],
            "has_exception": [],
            "is_exception_to": [],
            "follows": [],
            "is_followed_by": []
        }

    def add_refines(self, target: Union['Reference', 'Statement']):
        self.relationships["refines"].append(target)
    def add_exception(self, exception: Union['Reference', 'Statement']):
        self.relationships["has_exception"].append(exception)
    def add_follows(self, target: Union['Reference', 'Statement']):
        self.relationships["follows"].append(target)
    def add_is_refined_by(self, target: Union['Reference', 'Statement']):
        self.relationships["is_refined_by"].append(target)
    def add_is_exception_to(self, exception: Union['Reference', 'Statement']):
        self.relationships["is_exception_to"].append(exception)
    def add_is_followed_by(self, target: Union['Reference', 'Statement']):
        self.relationships["is_followed_by"].append(target)


class Information(Statement):
    \"""
    A type of Statement that represents something that is known or proved to be true.

    Attributes:
        description (List[Expression]): The Expressions that contains the factual information.
    \"""

    def __init__(self, section, description: Expression):
        super().__init__(section)
        self.description: List[Expression] = []
        if description is not None:
            self.description.append(description)


class Definition(Statement):
    \"""
    A type of Statement that defines a concept or term in the legal text.

    Attributes:
        defined_term (Expression): The Expression stating the term being defined.
        meaning (List[Expression]): One or more Expressions elaborating the meaning of the term.
        exclusions (List[Expression]): Expressions clarifying what the term excludes or does not cover.
    \"""

    def __init__(self, section, defined_term: Expression):
        super().__init__(section)
        self.defined_term: Expression = defined_term
        self.meaning: List[Expression] = []
        self.exclusions: List[Expression] = []


class Rule(Statement):
    \"""
    A Statement describing a legal rule, which may take one of four types: obligation,
    permission, prohibition, or penalty.

    Attributes:
        rule_type (int): An integer indicating which type of rule. Should be one of:
            OBLIGATION, PERMISSION, PROHIBITION, PENALTY.
        entity (Expression): The main entity (person, object, etc.) to which the rule applies.
        description (Expression): An Expression describing the rule.
        conditions (List[Expression]): Expressions indicating the conditions under which the rule applies.
    \"""

    OBLIGATION = 0
    PERMISSION = 1
    PROHIBITION = 2
    PENALTY = 3

    def __init__(self, section, entity: Expression):
        super().__init__(section)
        self.rule_type: int = None
        self.entity: Expression = entity
        self.description: Optional[Expression] = None
        self.conditions: List[Expression] = []


class Exemption(Statement):
    \"""
    A type of Statement indicating that a person, object, or situation is exempt
    from another rule or requirement.

    Attributes:
        description (List[Expression]): One or more Expressions describing the exemption.
    \"""

    def __init__(self, section=None, description: Optional[Expression] = None):
        super().__init__(section)
        self.description: List[Expression] = []
        if description is not None:
            self.description.append(description)
"""

In [None]:
# Prompt template for converting text to Python code using the class structure above.
# The three %s slots are, in order: the class structure, the demonstrations, the text to convert.
# NOTE: this rebinds the name `prompt` that an earlier cell used for the tagging template.
prompt = """Read the text and convert it to Python code. Use the class structure detailed below to write code. Do not create your own names. Examples have been provided. 

Class Structure:
%s

Examples: 
%s

Text: %s
Code: """

# Variant of the template without the class structure: its two %s slots are the
# demonstrations and the text to convert. It is not referenced by any cell shown here.
prompt2 = """Read the text and convert it to Python code. Examples have been provided. 

Examples: 
%s

Text: %s
Code: """


def exec_prompt(test_sample, development_set, n_demos=3):
    """
    Function to execute the prompt with the test sample and development set, returning the model's response.

    Demonstrations are chosen with select_demos; when none shares a tag with the
    test sample, random development samples are used instead.
    Args:
        test_sample (dict): The test sample containing 'text', 'tags' and 'embedding'.
        development_set (list): The development set containing demonstration samples,
            each with 'text', 'code', 'tags' and 'embedding'.
        n_demos (int): Maximum number of demonstrations to put into the prompt. Defaults to 3.
    Returns:
        str: The model's response containing the generated Python code.
    """
    # Select demonstrations based on the test sample and development set
    demos = select_demos(test_sample, development_set, n=n_demos)

    # If no demonstrations are selected, assign random demonstrations from the development set.
    # The sample size is capped by the pool size because random.sample raises ValueError
    # when asked for more items than the population holds.
    if len(demos) == 0:
        demos = random.sample(development_set, min(n_demos, len(development_set)))

    # Render each demonstration as a "Text / Code" pair with the code in a python fence
    examples = '\n\n'.join(
        'Text: %s\nCode: ```python\n%s\n```' % (d['text'], d['code']) for d in demos
    )

    # Fill the template with the fenced class structure, the demonstrations and the text
    p = prompt % ('```python\n' + code_string + '\n```', examples, test_sample['text'])

    # Call the model with the formatted prompt and return its response
    return prompt_model(p)

In [None]:
# Define the number of passes for the model to run
passes = 3 # Change this value to set the number of passes

# Execute the prompt for each test sample in the testing set for the specified number of passes.
# Every pass re-prompts the model for all samples and writes its own CSV file; the 'code'
# entry of each test sample is overwritten by the latest pass.
for j in range(passes):
    print('Pass %i' % (j + 1))
    print('Prompting for %i test samples' % len(testing_set), end='')
    answers = []
    # For each test sample, execute the prompt and collect the answers
    for t in testing_set:
        a = exec_prompt(t, development_set)
        answers.append(a)
        print('. ', end='')
    print()
    # Assign the generated code to the corresponding test sample in the testing set
    for t, a in zip(testing_set, answers):
        t['code'] = a
    output_file = f'testing_set_pass_{j + 1}.csv' # Output file for each pass
    try:
        # Write the testing set with the generated code to a CSV file.
        # newline='' hands line-ending control to the csv module, which is required for
        # fields with embedded newlines (the generated code); utf-8 keeps non-ASCII
        # characters writable regardless of the platform default encoding.
        with open(output_file, 'w', newline='', encoding='utf-8') as f:
            writer = csv.writer(f)
            writer.writerow(['text', 'code', 'tags'])
            for t in testing_set:
                writer.writerow([t['text'], t['code'], t['tags']])
        print(f'Wrote {output_file}')
    except Exception as e:
        # Handle any exceptions that occur during file writing and stop further passes
        print('Error writing file: ', e)
        break