In [1]:
import pandas as pd
import re
import nltk
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer

# Download necessary NLTK data
nltk.download('stopwords')
nltk.download('wordnet')

from sentence_transformers import SentenceTransformer
from tqdm import tqdm
import numpy as np
import pickle

import faiss

[nltk_data] Downloading package stopwords to
[nltk_data]     /home/infres/apereira-22/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package wordnet to
[nltk_data]     /home/infres/apereira-22/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


In [2]:
def load_dataset(csv_path):
    """
    Load the dataset from a CSV file.
    """
    df = pd.read_csv(csv_path)
    return df

def perform_eda(df):
    """
    Perform basic EDA on the dataset.
    """
    print("First 5 records:")
    print(df.head())

    print("\nDataset Info:")
    print(df.info())

    print("\nMissing Values:")
    print(df.isnull().sum())

    print("\nClass Distribution:")
    print(df['case_outcome'].value_counts())

def clean_text(text):
    """
    Clean the input text by:
    - Lowercasing
    - Removing special characters and digits
    - Removing stopwords
    - Lemmatizing
    """
    # Lowercase
    text = text.lower()

    # Remove special characters and digits
    text = re.sub(r'[^a-z\s]', '', text)

    # Tokenize
    tokens = text.split()

    # Remove stopwords
    stop_words = set(stopwords.words('english'))
    tokens = [word for word in tokens if word not in stop_words]

    # Lemmatize
    lemmatizer = WordNetLemmatizer()
    tokens = [lemmatizer.lemmatize(word) for word in tokens]

    # Join back to string
    cleaned_text = ' '.join(tokens)
    return cleaned_text

def preprocess_texts(df):
    """
    Apply text cleaning to 'case_title' and 'case_text'.
    """
    df['cleaned_title'] = df['case_title'].apply(lambda x: clean_text(str(x)))
    df['cleaned_text'] = df['case_text'].apply(lambda x: clean_text(str(x)))
    return df

def split_into_chunks(text, max_length=512):
    """
    Split text into chunks of maximum 'max_length' words.
    """
    words = text.split()
    chunks = [' '.join(words[i:i + max_length]) for i in range(0, len(words), max_length)]
    return chunks

def split_texts(df, max_length=512):
    """
    Apply text splitting to 'cleaned_text'.
    Each case can have multiple chunks.
    """
    df['text_chunks'] = df['cleaned_text'].apply(lambda x: split_into_chunks(x, max_length))
    return df

def generate_embeddings(df, model_name='sentence-transformers/all-MiniLM-L6-v2'):
    """
    Generate embeddings for each text chunk.
    """
    model = SentenceTransformer(model_name)
    all_chunks = df['text_chunks'].explode().tolist()
    
    print("Generating embeddings...")
    embeddings = model.encode(all_chunks, show_progress_bar=True, convert_to_numpy=True)
    
    return all_chunks, embeddings

def build_faiss_index(embeddings, index_path):
    """
    Build and save a FAISS index from embeddings.
    """
    dimension = embeddings.shape[1]
    index = faiss.IndexFlatL2(dimension)  # Using L2 distance; consider IndexHNSWFlat for larger datasets
    index.add(embeddings)
    faiss.write_index(index, index_path)
    print(f"FAISS index built and saved to {index_path}")
    return index

def save_chunks(all_chunks, doc_ids, save_path='data/processed_data/documents.pkl'):
    """
    Save the text chunks and their corresponding document IDs.
    """
    with open(save_path, 'wb') as f:
        pickle.dump({'chunks': all_chunks, 'doc_ids': doc_ids}, f)
    print(f"Documents saved to {save_path}")

def save_preprocessed_dataframe(df, save_path):
    """
    Save the preprocessed dataframe to a CSV file.
    """
    df.to_csv(save_path, index=False)
    print(f"Preprocessed data saved to {save_path}")

In [3]:
csv_path = "raw_legal_text_classification.csv"
df = load_dataset(csv_path)

In [4]:
perform_eda(df)

First 5 records:
  case_id case_outcome                                         case_title  \
0   Case1        cited  Alpine Hardwood (Aust) Pty Ltd v Hardys Pty Lt...   
1   Case2        cited  Black v Lipovac [1998] FCA 699 ; (1998) 217 AL...   
2   Case3        cited  Colgate Palmolive Co v Cussons Pty Ltd (1993) ...   
3   Case4        cited  Dais Studio Pty Ltd v Bullett Creative Pty Ltd...   
4   Case5        cited  Dr Martens Australia Pty Ltd v Figgins Holding...   

                                           case_text  
0  Ordinarily that discretion will be exercised s...  
1  The general principles governing the exercise ...  
2  Ordinarily that discretion will be exercised s...  
3  The general principles governing the exercise ...  
4  The preceding general principles inform the ex...  

Dataset Info:
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 24985 entries, 0 to 24984
Data columns (total 4 columns):
 #   Column        Non-Null Count  Dtype 
---  ------        ------

In [5]:
df = df.dropna(subset=['case_text'])
perform_eda(df)

First 5 records:
  case_id case_outcome                                         case_title  \
0   Case1        cited  Alpine Hardwood (Aust) Pty Ltd v Hardys Pty Lt...   
1   Case2        cited  Black v Lipovac [1998] FCA 699 ; (1998) 217 AL...   
2   Case3        cited  Colgate Palmolive Co v Cussons Pty Ltd (1993) ...   
3   Case4        cited  Dais Studio Pty Ltd v Bullett Creative Pty Ltd...   
4   Case5        cited  Dr Martens Australia Pty Ltd v Figgins Holding...   

                                           case_text  
0  Ordinarily that discretion will be exercised s...  
1  The general principles governing the exercise ...  
2  Ordinarily that discretion will be exercised s...  
3  The general principles governing the exercise ...  
4  The preceding general principles inform the ex...  

Dataset Info:
<class 'pandas.core.frame.DataFrame'>
Index: 24809 entries, 0 to 24984
Data columns (total 4 columns):
 #   Column        Non-Null Count  Dtype 
---  ------        -----------

In [6]:
df['case_id'] = df['case_id'].str.replace('Case', '').astype(int)
df['case_outcome'] = df['case_outcome'].astype(str)
df['case_title'] = df['case_title'].astype(str)
df['case_text'] = df['case_text'].astype(str)
perform_eda(df)

First 5 records:
   case_id case_outcome                                         case_title  \
0        1        cited  Alpine Hardwood (Aust) Pty Ltd v Hardys Pty Lt...   
1        2        cited  Black v Lipovac [1998] FCA 699 ; (1998) 217 AL...   
2        3        cited  Colgate Palmolive Co v Cussons Pty Ltd (1993) ...   
3        4        cited  Dais Studio Pty Ltd v Bullett Creative Pty Ltd...   
4        5        cited  Dr Martens Australia Pty Ltd v Figgins Holding...   

                                           case_text  
0  Ordinarily that discretion will be exercised s...  
1  The general principles governing the exercise ...  
2  Ordinarily that discretion will be exercised s...  
3  The general principles governing the exercise ...  
4  The preceding general principles inform the ex...  

Dataset Info:
<class 'pandas.core.frame.DataFrame'>
Index: 24809 entries, 0 to 24984
Data columns (total 4 columns):
 #   Column        Non-Null Count  Dtype 
---  ------        -----

In [7]:
df = preprocess_texts(df)

In [8]:
df['cleaned_text'].apply(lambda x: len(x.split())).describe()

count    24809.000000
mean       222.178000
std        558.954086
min         11.000000
25%         74.000000
50%        120.000000
75%        212.000000
max      10998.000000
Name: cleaned_text, dtype: float64

In [9]:
df = split_texts(df)

In [11]:
all_chunks, embeddings = generate_embeddings(df)

Generating embeddings...


Batches:   0%|          | 0/888 [00:00<?, ?it/s]

In [None]:
all_chunks, embeddings = generate_embeddings(df)
with open('temp_all_chunks.pkl', 'wb') as f:
    pickle.dump(all_chunks, f)
np.save('temp_embeddings.npy', embeddings)

In [None]:
all_chunks = pickle.load(open('temp_all_chunks.pkl', 'rb'))
embeddings = np.load('temp_embeddings.npy')

In [25]:
len(all_chunks), embeddings.shape

(28391, (28391, 384))

In [5]:
index = build_faiss_index(embeddings, 'faiss_index.index')

FAISS index built and saved to faiss_index.index


In [26]:
doc_ids = df.loc[df.index.repeat(df['text_chunks'].str.len()), 'case_id'].tolist()
save_chunks(all_chunks, doc_ids, save_path="documents.pkl")

Documents saved to documents.pkl


In [28]:
len(doc_ids)

28391

In [22]:
save_preprocessed_dataframe(df, 'preprocessed_dataframe.csv')

Preprocessed data saved to preprocessed_dataframe.csv
