In [1]:
import os
import numpy as np
import pandas as pd
import faiss
import tensorflow_hub as hub
from pprint import pprint

import openai
# Take the OpenAI key from the environment rather than hardcoding it;
# indexing os.environ raises KeyError right here if OPENAI_API_KEY is unset.
openai.api_key = os.environ["OPENAI_API_KEY"]


In [2]:
# Load the legal-case corpus from the local data folder and preview the first rows.
legal_docs_path = "./data/legal_text_classification.csv"
legal_docs = pd.read_csv(legal_docs_path)
legal_docs.head(5)

Unnamed: 0,case_id,case_outcome,case_title,case_text
0,Case1,cited,Alpine Hardwood (Aust) Pty Ltd v Hardys Pty Lt...,Ordinarily that discretion will be exercised s...
1,Case2,cited,Black v Lipovac [1998] FCA 699 ; (1998) 217 AL...,The general principles governing the exercise ...
2,Case3,cited,Colgate Palmolive Co v Cussons Pty Ltd (1993) ...,Ordinarily that discretion will be exercised s...
3,Case4,cited,Dais Studio Pty Ltd v Bullett Creative Pty Ltd...,The general principles governing the exercise ...
4,Case5,cited,Dr Martens Australia Pty Ltd v Figgins Holding...,The preceding general principles inform the ex...


In [3]:
# Size of the frame, then the distinct labels found in the case_outcome column.
n_rows, n_cols = legal_docs.shape
print((n_rows, n_cols))
outcome_labels = legal_docs["case_outcome"].unique()
print(outcome_labels)


array(['cited', 'applied', 'followed', 'referred to', 'related',
       'considered', 'discussed', 'distinguished', 'affirmed', 'approved'],
      dtype=object)

In [4]:
# Pretty-print the full text of the document at position 100 as a sample.
sample_text = legal_docs["case_text"].iloc[100]
pprint(sample_text)

('Gedeon v Commissioner of New South Wales Crime Commission [2008] HCA 43 ; '
 '(2008) 82 ALJR 1465 at [43] the High Court said: The expression '
 '"jurisdictional fact" was used somewhat loosely in the course of '
 'submissions. Generally the expression is used to identify a criterion the '
 'satisfaction of which enlivens the exercise of the statutory power or '
 'discretion in question. If the criterion be not satisfied then the decision '
 'purportedly made in exercise of the power or discretion will have been made '
 'without the necessary statutory authority required of the decision maker.')


In [5]:
# DataFrame.info() writes its report to stdout and returns None, so it is
# called directly; wrapping it in print() only appended a stray "None" line.
legal_docs.info()

# Missing values per column; left as the last expression so the notebook displays it.
legal_docs.isna().sum()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 24985 entries, 0 to 24984
Data columns (total 4 columns):
 #   Column        Non-Null Count  Dtype 
---  ------        --------------  ----- 
 0   case_id       24985 non-null  object
 1   case_outcome  24985 non-null  object
 2   case_title    24985 non-null  object
 3   case_text     24809 non-null  object
dtypes: object(4)
memory usage: 780.9+ KB
None


case_id           0
case_outcome      0
case_title        0
case_text       176
dtype: int64

In [3]:
# Where case_text is missing, substitute the same row's case_title.
legal_docs["case_text"] = legal_docs["case_text"].fillna(legal_docs["case_title"])

In [4]:
# TF Hub handle of the sentence encoder used for embeddings; the "large" v5
# handle is kept commented out as an alternative.
module_url = "https://tfhub.dev/google/universal-sentence-encoder/4"
#module_url = "https://tfhub.dev/google/universal-sentence-encoder-large/5"

model = hub.load(module_url)
print("module %s loaded" % (module_url,))


def embed(input):
  """Pass `input` to the loaded TF Hub model and return the model's output unchanged."""
  encoded = model(input)
  return encoded

module https://tfhub.dev/google/universal-sentence-encoder/4 loaded


In [11]:
# without token chunking: every document is embedded whole, docs_chunk_size documents per call
text_all = legal_docs["case_text"].to_list()
docs_chunk_size = 1024
current = 0  # advanced by one full chunk per batch; not read anywhere in this cell

# Fail with a clear message instead of an opaque error further down.
if not text_all:
  raise ValueError("no documents to embed: legal_docs['case_text'] is empty")

# Embed batch by batch. Stepping the range by docs_chunk_size visits exactly the
# chunks that exist, so there is no trailing empty chunk and no bogus
# "start > end" progress line.
text_embeddings_all = []
for start_i in range(0, len(text_all), docs_chunk_size):
  end_i = min(start_i + docs_chunk_size, len(text_all))
  print (start_i, end_i)

  docs = text_all[start_i:end_i]
  text_embeddings_all.append(embed(docs).numpy())

  current += docs_chunk_size

# Stack the per-batch arrays row-wise. concatenate copes with a short final
# batch and with a corpus smaller than one chunk, both of which the previous
# pop/reshape/np.r_ sequence had to special-case (and the latter crashed on).
text_embeddings_all_np = np.concatenate(text_embeddings_all, axis=0)

# One embedding row per document, in the same order as legal_docs.
assert text_embeddings_all_np.shape[0] == len(text_all)


0 1024
1024 2048
2048 3072
3072 4096
4096 5120
5120 6144
6144 7168
7168 8192
8192 9216
9216 10240
10240 11264
11264 12288
12288 13312
13312 14336
14336 15360
15360 16384
16384 17408
17408 18432
18432 19456
19456 20480
20480 21504
21504 22528
22528 23552
23552 24576
24576 24985
25600 24985


In [12]:
# NOT WORKING - documents differ in length, so they split into different numbers of chunks
# (e.g. one doc can produce 8 embedded vectors and another only 1); a vector's position therefore no longer matches its doc's row.
# A vector-index -> doc-index mapping is needed, probably an external lookup that is consulted during search
# to convert a chunk index back to a doc index. Note: the "tokens" below are characters, because `doc` is sliced as a string.

# with token chunking

# text_all = legal_docs["case_text"].to_list()
# docs_chunk_size = 1024
# tokens_chunk_size = 512
# current = 0

# # chunk and create embedded vectors
# text_embeddings_all = []
# for i in range(len(text_all)):
#   # chunk docs
#   start_i = i*docs_chunk_size
#   end_i = (i+1)*docs_chunk_size
#   end_i = end_i if end_i < len(text_all) else len(text_all)

#   print (start_i, end_i)

#   docs = text_all[i*docs_chunk_size : (i+1)*docs_chunk_size]
#   if not docs: break

#   # chunk docs tokens
#   for doc in docs:
#     current_token = 0
#     while current_token < len(doc):
#       tokens = [doc[current_token : current_token + tokens_chunk_size]]
#       text_embeddings = embed(tokens)
#       text_embeddings_np = text_embeddings.numpy()
#       text_embeddings_all.append(text_embeddings_np)

#       current_token += tokens_chunk_size

#   current += docs_chunk_size

# # handle last chunk before creating reshaping array
# if len(text_embeddings_all[-1]) != docs_chunk_size:
#   last_chunk = text_embeddings_all.pop()
#   last_chunk_np = np.array(last_chunk)
#   text_embeddings_all_np = np.array(text_embeddings_all)
  
#   n_rows = text_embeddings_all_np.shape[0] * text_embeddings_all_np.shape[1]
#   text_embeddings_all_np = text_embeddings_all_np.reshape(n_rows, -1)
#   text_embeddings_all_np = np.r_[text_embeddings_all_np, last_chunk_np]
# else:
#   text_embeddings_all_np = np.array(text_embeddings_all)
#   n_rows = text_embeddings_all_np.shape[0] * text_embeddings_all_np.shape[1]
#   text_embeddings_all_np = text_embeddings_all_np.reshape(n_rows, -1)
  


In [13]:
# Sanity check: rows should equal the number of documents; columns are the embedding width.
text_embeddings_all_np.shape

(24985, 512)

In [14]:
# Build a flat L2 FAISS index sized to the embedding width, add every document
# vector to it, and persist the index next to the notebook.
embedding_dim = text_embeddings_all_np.shape[1]
index = faiss.IndexFlatL2(embedding_dim)
index.add(text_embeddings_all_np)
faiss.write_index(index, "court_text.index")