# LawChat

**Use natural language to query a library of previously cited legal cases**

We make use of LangChain for:

- Document Loaders: WebBaseLoader
- Text Splitter: TokenTextSplitter
- Vector Store: Cassandra
- Models: OpenAIEmbeddings and ChatOpenAI
- Chain: FLARE

We will be using the Australasian Legal Information Institute (AustLII) as our source of legislation and court judgments ("case law").

## Configuration

**Create your `.env` file**

1. Copy the `.env.example` file to `.env`.
2. Specify your Astra and OpenAI parameters.

In [1]:
import os
from dotenv import load_dotenv

# Load the variables defined in .env into the process environment
load_dotenv()

ASTRA_DB_KEYSPACE = os.environ['ASTRA_DB_KEYSPACE']
ASTRA_DB_SECURE_BUNDLE_PATH = os.environ['ASTRA_DB_SECURE_BUNDLE_PATH']
ASTRA_DB_APPLICATION_TOKEN = os.environ['ASTRA_DB_APPLICATION_TOKEN']
OPENAI_API_KEY = os.environ['OPENAI_API_KEY']
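
If one of these variables is missing, the cell above fails with a bare `KeyError` on the first absent name. The sketch below reports every missing setting at once. It assumes only the four variable names used above; `REQUIRED_VARS` and `missing_settings` are hypothetical names introduced for illustration and are not part of the notebook.

```python
import os

# The four settings read by the configuration cell.
REQUIRED_VARS = (
    "ASTRA_DB_KEYSPACE",
    "ASTRA_DB_SECURE_BUNDLE_PATH",
    "ASTRA_DB_APPLICATION_TOKEN",
    "OPENAI_API_KEY",
)


def missing_settings(env=os.environ, required=REQUIRED_VARS):
    """Return the names of required settings that are unset or empty (hypothetical helper)."""
    return [name for name in required if not env.get(name)]


# Demonstration with a plain dict standing in for os.environ
example_env = {"ASTRA_DB_KEYSPACE": "lawchat", "OPENAI_API_KEY": ""}
print(missing_settings(example_env))
```

Calling `missing_settings()` with no arguments checks the real environment, so it could be run right after `load_dotenv()` to confirm the `.env` file is complete.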

## Astra DB Connectivity

In [2]:
from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider
from cassandra.cqlengine import connection


# load settings and keys
#settings = Settings


def getCluster():
    """
    Create a Cluster instance to connect to Astra DB.
    Uses the secure-connect-bundle and the connection secrets.
    """
    cloud_config = {"secure_connect_bundle": ASTRA_DB_SECURE_BUNDLE_PATH}
    auth_provider = PlainTextAuthProvider("token", ASTRA_DB_APPLICATION_TOKEN)
    return Cluster(cloud=cloud_config, auth_provider=auth_provider)


def get_astra():
    """
    This function is used by LangChain Vectorstore.
    """
    cluster = getCluster()
    astraSession = cluster.connect()
    return astraSession, ASTRA_DB_KEYSPACE


def initSession():
    """
    Create the DB session and return it to the caller.
    Most important, the session is also set as default and made available
    to the object mapper through global settings. I.e., no need to actually
    do anything with the return value of this function.
    """
    cluster = getCluster()
    session = cluster.connect()
    session.set_keyspace(ASTRA_DB_KEYSPACE)
    connection.register_connection("my-astra-session", session=session)
    connection.set_default_connection("my-astra-session")
    return connection

#### Define the Vector Store

In [7]:
from langchain.vectorstores import Cassandra
from langchain.embeddings.openai import OpenAIEmbeddings

# define Embedding model
embeddings = OpenAIEmbeddings()

# Set up the vector store
print("Setup Vector Store")
session, keyspace = get_astra()
vectorstore = Cassandra(
    embedding=embeddings,
    session=session,
    keyspace=keyspace,
    table_name="nswsc",
)

Setup Vector Store


## Load library data

We are using data from the Australasian Legal Information Institute. AustLII maintains collections of primary materials: legislation and court judgments ("case law"). 

For this project, we are sourcing case law from the Supreme Court of New South Wales.

#### Utility functions

In [8]:
import re

def clean_text(text: str) -> str:
    """Clean text extracted from web pages."""
    # Normalize line breaks to \n\n (two new lines)
    text = text.replace("\r\n", "\n\n")
    text = text.replace("\r", "\n\n")

    # Replace two or more spaces with a single space
    text = re.sub(" {2,}", " ", text)

    # Remove leading spaces before removing trailing spaces
    text = re.sub("^[ \t]+", "", text, flags=re.MULTILINE)

    # Remove trailing spaces before removing empty lines
    text = re.sub("[ \t]+$", "", text, flags=re.MULTILINE)

    # Remove empty lines
    text = re.sub(r"^\s+", "", text, flags=re.MULTILINE)

    return text
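
To see what the cleaning steps do, it can help to run them on a small string. The sketch below repeats the function so that it runs on its own; the sample input `raw` is made up for illustration.

```python
import re


def clean_text(text: str) -> str:
    """Clean text extracted from web pages (same steps as the notebook cell)."""
    # Normalize line breaks to \n\n (two new lines)
    text = text.replace("\r\n", "\n\n")
    text = text.replace("\r", "\n\n")
    # Replace two or more spaces with a single space
    text = re.sub(" {2,}", " ", text)
    # Remove leading, then trailing, spaces and tabs on every line
    text = re.sub("^[ \t]+", "", text, flags=re.MULTILINE)
    text = re.sub("[ \t]+$", "", text, flags=re.MULTILINE)
    # Remove empty lines
    text = re.sub(r"^\s+", "", text, flags=re.MULTILINE)
    return text


# Made-up sample: padded lines, Windows line endings and a blank line
raw = "  Hello   world  \r\n\r\n  Second line\t\r\n"
cleaned = clean_text(raw)
print(repr(cleaned))  # 'Hello world\nSecond line\n'
```

Note that the final step collapses the doubled line breaks again, so the result has one non-empty line per row.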


In [9]:
import tiktoken

# cl100k_base is the encoding used by gpt-3.5-turbo and the OpenAI embedding models
encoding = tiktoken.get_encoding("cl100k_base")

def num_tokens_from_string(string: str) -> int:
    """Return the number of tokens in a text string."""
    return len(encoding.encode(string))

#### Get data files

We load a number of HTML pages using the LangChain WebBaseLoader. Each of those pages contains a lot of superfluous content, so we extract only the relevant article content.

In [10]:
from langchain.document_loaders import WebBaseLoader

urls = [
    "https://www.austlii.edu.au/cgi-bin/viewdoc/au/cases/nsw/NSWSC/1998/423.html",
    "https://www8.austlii.edu.au/cgi-bin/viewdoc/au/cases/nsw/NSWSC/2002/949.html",
    "https://www8.austlii.edu.au/cgi-bin/viewdoc/au/cases/nsw/NSWSC/1998/4.html",
    "https://www8.austlii.edu.au/cgi-bin/viewdoc/au/cases/nsw/NSWSC/2005/1181.html",
    "https://www8.austlii.edu.au/cgi-bin/viewdoc/au/cases/nsw/NSWSC/1998/483.html"
    ]
headers = {
    "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/102.0.0.0 Safari/537.36"
}

print("Loading Data")
url_loaders = WebBaseLoader(urls, header_template=headers)
data = url_loaders.load()

# Extract only the actual article content from each web page and clean it
print("Cleaning Data")
for d in data:
    source = d.metadata["source"]
    thedoc = WebBaseLoader(source, header_template=headers).scrape()
    # Keep only the <article class="the-document"> element
    article = thedoc.find("article", {"class": "the-document"})
    if article is None:
        raise ValueError(f"No article content found at {source}")
    d.page_content = clean_text(article.text)

print(f"Number of documents: {len(urls)}")
# Token count of the first document only
print(f"Number of tokens: {num_tokens_from_string(data[0].page_content)}")


Loading Data
Cleaning Data
Number of documents: 5
Number of tokens: 33772
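
The extraction step keeps only the text inside the `<article class="the-document">` element. The sketch below illustrates the same idea on an inline HTML string. It deliberately swaps BeautifulSoup (which the notebook uses through `WebBaseLoader.scrape()`) for the standard library `html.parser` so that it runs without network access or extra packages. `ArticleTextExtractor`, `extract_article_text` and `sample_html` are hypothetical names and data introduced for illustration.

```python
from html.parser import HTMLParser


class ArticleTextExtractor(HTMLParser):
    """Collect the text found inside <article class="..."> elements."""

    def __init__(self, css_class):
        super().__init__()
        self.css_class = css_class
        self.depth = 0   # > 0 while we are inside a matching <article>
        self.parts = []  # text fragments collected so far

    def handle_starttag(self, tag, attrs):
        if tag != "article":
            return
        classes = (dict(attrs).get("class") or "").split()
        # Enter on a matching article; also count nested articles inside it
        if self.depth > 0 or self.css_class in classes:
            self.depth += 1

    def handle_endtag(self, tag):
        if tag == "article" and self.depth > 0:
            self.depth -= 1

    def handle_data(self, data):
        if self.depth > 0:
            self.parts.append(data)


def extract_article_text(html, css_class="the-document"):
    """Return the article text, or None when no matching article exists."""
    parser = ArticleTextExtractor(css_class)
    parser.feed(html)
    parser.close()
    return "".join(parser.parts) if parser.parts else None


# Made-up page: navigation and footer should be dropped
sample_html = (
    "<html><body><nav>Menu</nav>"
    '<article class="the-document"><h1>Judgment</h1><p>Orders made.</p></article>'
    "<footer>Footer</footer></body></html>"
)
print(extract_article_text(sample_html))
```

Returning `None` for a page without the expected element makes the failure case explicit, which is useful if a URL redirects or the site layout changes.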


#### Split the data into chunks

In [12]:
from langchain.text_splitter import TokenTextSplitter

CHUNK_SIZE = 500

# Chunk the data
print("Splitting Data")
text_splitter = TokenTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=50)
docs = text_splitter.split_documents(data)
print(f"Number of chunks: {len(docs)}")

Splitting Data
Number of chunks: 226
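
With `chunk_size=500` and `chunk_overlap=50`, each chunk starts 450 tokens after the previous one, so neighbouring chunks share 50 tokens. The sketch below illustrates that sliding-window idea on a plain list. It is not LangChain's implementation, and the exact handling of the final chunk may differ between library versions; `sliding_windows` is a hypothetical helper.

```python
def sliding_windows(tokens, chunk_size, overlap):
    """Split a token list into overlapping windows (illustrative sketch)."""
    if not 0 <= overlap < chunk_size:
        raise ValueError("overlap must be non-negative and smaller than chunk_size")
    step = chunk_size - overlap
    chunks = []
    start = 0
    while start < len(tokens):
        end = min(start + chunk_size, len(tokens))
        chunks.append(tokens[start:end])
        if end == len(tokens):
            break  # the last window already reaches the end of the input
        start += step
    return chunks


# Ten stand-in "tokens", windows of 4 with an overlap of 1
windows = sliding_windows(list(range(10)), chunk_size=4, overlap=1)
print(windows)  # [[0, 1, 2, 3], [3, 4, 5, 6], [6, 7, 8, 9]]
```

The overlap means that a sentence cut at a chunk boundary still appears whole in at least one chunk, at the cost of storing some tokens twice.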


#### Store data and embeddings in Astra DB

In [13]:
print("Adding texts to Vector Store")
texts, metadatas = zip(*((doc.page_content, doc.metadata) for doc in docs))
vectorstore.add_texts(texts=texts, metadatas=metadatas)

Adding texts to Vector Store


['d93f992af1324a1f9c3eb75b35edf6cb',
 'ecf50703bc1e43549c83c8b19e53af30',
 '18172ccd31c1419fbfea182d5db029fe',
 '0894067ba4c74d73808341f1b8999920',
 'a7ce0b0981ad458b9bdd8ec9ed2207b5',
 '33ac26f6f84c4492915641779b0808b1',
 '9364f696959e4d6689402ef4f96b21d2',
 '16d84c00e31c4cb9ad19e3a89ee99b4d',
 'de52c686a3ab40e0bb196075903af171',
 'c62bcd6256354df580076792d3cf8482',
 '8f1fd76faa424ee9b7bdc29947c1732a',
 '3e22cf925f094428acbd2cc15a88b2f8',
 '9f98c9e6a97449ce9c158b17b4c26219',
 '13f5bdba0a504fe1bc3660d8567ecdca',
 '41681f9491074372afb9e095c13341df',
 '517fd6d020224311a5a1e74022d8b34d',
 '8fbd116e2dbe4fa5a5c36b132276f41c',
 '184cd8112eb7415d8f7ed51133853154',
 '1241c9310c824c459c08ce22f661b7cb',
 '8f9c1ba225ba46d3a6b1b7f1a33ac41d',
 'ed5ae1a9f20c46dbaf960da8ea8a048c',
 '541421c166854bbc9b81d4067507a76b',
 '7e51186361094256982493a7c8f0183d',
 '089e5a630abc413bb08bc8378136c43c',
 '08d16b326f164a6480b71c4308603541',
 '46c9c391924a4e009379ab8856625d44',
 '819893904c3b4604a614148540af53d9',
 

## Query

#### Define the Retriever
By setting `k` in `search_kwargs`, we can control how many results are retrieved from the vector store and passed into the FLARE chain.

- Ideally we would not set `k` at all, but the only way I could get this to work was to limit `k` to a single result so that the prompt does not breach the token limit of 4096.

In [14]:
from langchain.chains import FlareChain
from langchain.chat_models import ChatOpenAI

# The default FLARE chain appears to have the OpenAI model hardcoded
# to a 4096-token model, so we have to limit the number of tokens.
# One workaround is to retrieve a single result per query with k=1:
# retriever = vectorstore.as_retriever(search_kwargs={"k": 1})
retriever = vectorstore.as_retriever()
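
The trade-off between `k`, the chunk size and the 4096-token limit is simple arithmetic. The sketch below is a rough budget estimate, not a description of how FLARE builds its prompt. It assumes that every retrieved chunk is about `CHUNK_SIZE` tokens, that a retrieval pass may issue several queries whose results are concatenated, and that 300 tokens are enough for the prompt template and question. That overhead figure, and the names `max_results_per_query`, `reserved_tokens` and `queries`, are assumptions made for illustration.

```python
CONTEXT_WINDOW = 4096      # token limit reported in the error
CHUNK_SIZE = 500           # chunk size used when splitting the documents
MAX_GENERATION_LEN = 164   # tokens kept free for the generated answer
PROMPT_OVERHEAD = 300      # assumed allowance for template and question


def max_results_per_query(context_window, chunk_size, reserved_tokens, queries=1):
    """Largest k such that queries * k chunks still fit in the context window."""
    available = context_window - reserved_tokens
    if available <= 0:
        return 0
    return available // (chunk_size * queries)


reserved = MAX_GENERATION_LEN + PROMPT_OVERHEAD
print(max_results_per_query(CONTEXT_WINDOW, CHUNK_SIZE, reserved))             # 7
print(max_results_per_query(CONTEXT_WINDOW, CHUNK_SIZE, reserved, queries=4))  # 1
```

Under these assumptions a single query leaves room for several chunks, while four concatenated queries leave room for only one each. Halving the chunk size doubles the number of results that fit in the same budget.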

#### Define FLARE chain

It appears that the `model` argument is ignored: an error is thrown if more than 4096 tokens are supplied, even though `gpt-3.5-turbo-16k` is requested.

- The only ways I could get this to work are to either:

    - limit `k` above to a single result so as not to breach the token limit, which means we are not retrieving enough context from the data; or
    - reduce the chunk size to create smaller chunks and thus reduce the number of tokens passed to the LLM.

In [None]:
flare = FlareChain.from_llm(
    ChatOpenAI(temperature=0, model="gpt-3.5-turbo-16k"),
    retriever=retriever,
    max_generation_len=164,
    min_prob=0.3,
)

# Define your query
query = "List the orders from the case AMALGAMATED vs MARSDEN"

flare_result = flare.run(query)

print(f"QUERY: {query}\n\n")
print(f"FLARE RESULT:\n    {flare_result}\n\n")

#### Use the OpenAI LLM without FLARE for comparison

In [None]:
from langchain.llms import OpenAI

llm = OpenAI()
llm_result = llm(query)
print(f"LLM RESULT:\n    {llm_result}\n\n")