In [1]:
!pip install textacy
!pip install faiss-gpu
!pip install chromadb

Collecting textacy
  Downloading textacy-0.13.0-py3-none-any.whl.metadata (5.3 kB)
Collecting floret~=0.10.0 (from textacy)
  Downloading floret-0.10.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.1 kB)
Collecting pyphen>=0.10.0 (from textacy)
  Downloading pyphen-0.17.2-py3-none-any.whl.metadata (3.2 kB)
Downloading textacy-0.13.0-py3-none-any.whl (210 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m210.7/210.7 kB[0m [31m5.3 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hDownloading floret-0.10.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (320 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m320.4/320.4 kB[0m [31m14.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading pyphen-0.17.2-py3-none-any.whl (2.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m38.8 MB/s[0m eta [36m0:00:00[0m00:01[0m
[?25hInstalling collected packages: pyphen, floret, textacy
Successful

In [2]:
%%capture
# Initiliase Ollama server on a thread

#Download ollama
!curl -fsSL https://ollama.com/install.sh | sh
import subprocess
process = subprocess.Popen("ollama serve", shell=True) #runs on a different thread
#Download model
!ollama pull llama3
!pip install ollama
import ollama
model_name = 'llama3'

In [None]:
import re
import asyncio
import pandas as pd
import numpy as np
from pydantic import BaseModel
from ollama import AsyncClient
import textacy.datasets
import requests
import json
from sentence_transformers import SentenceTransformer
from textwrap import shorten
import faiss
from textacy.representations.vectorizers import Vectorizer
from sklearn.metrics.pairwise import cosine_similarity
import chromadb

In [None]:
# Load the Supreme Court dataset (downloaded on first use)
sc = textacy.datasets.SupremeCourt()
sc.download()

# Area code classification for 15 labels
# print(sc.issue_area_codes)

# Number of cases to classify. Kept small because every case costs one LLM call.
N_CASES = 100

# Build the working DataFrame from the first N_CASES records only, instead of
# materialising the whole corpus and slicing it afterwards.
cases = [{"text": record.text} for record in sc.records(limit=N_CASES)]
cases_df = pd.DataFrame(cases)

# Display the first few rows (must be the cell's last expression to render)
cases_df.head()

In [None]:
# Ask the model which words and phrases it associates with each issue area.
# The reply is parsed into a category -> keywords mapping in a later cell.
keyword_prompt = '''Give me a list of words and phrases (atleast 15) which would be present in a case text to classify it into following categories. Do not write any explaination or warnings.
            - Criminal Procedure: Cases involving the process of investigating and prosecuting crimes.
            - Civil Rights: Cases about the rights of individuals to receive equal treatment.
            - First Amendment: Cases addressing issues related to freedom of speech, religion, or press.
            - Due Process: Cases focusing on legal safeguards ensuring fair treatment.
            - Privacy: Cases dealing with the right to privacy.
            - Attorneys: Cases about the legal profession and lawyer regulations.
            - Unions: Cases concerning labor unions and collective bargaining.
            - Economic Activity: Cases involving business, trade, or commerce.
            - Judicial Power: Cases addressing the powers of courts and the judiciary.
            - Federalism: Cases about the division of power between state and federal governments.
            - Interstate Relations: Cases about interactions between states.
            - Federal Taxation: Cases involving federal tax laws.
            - Miscellaneous: Cases that do not clearly fit into any of the above categories.
            - Private Action: Cases involving disputes between private individuals or entities.

    Format the output in the following way:
    'Category': words, phrase ...
      '''

response = ollama.chat(
    model=model_name,
    messages=[{'role': 'user', 'content': keyword_prompt}],
)

In [None]:
# print(response['message']['content'])

def parse_text_to_dict(text):
    """
    Parse 'Category: word, phrase, ...' lines into a dictionary.

    Each line is split at its first colon: the left part becomes the category
    and the right part is split on commas into keywords. Lines without a
    colon, lines with an empty category, and lines that yield no keywords
    (e.g. an introductory "Here is the list:") are skipped. Empty keywords
    produced by trailing or doubled commas are dropped.

    Args:
        text (str): Raw multi-line text, typically an LLM reply.

    Returns:
        dict: Maps each category (str) to its list of stripped, non-empty
        keywords. If a category appears more than once, the last line wins.
    """
    result_dict = {}
    if not text:
        return result_dict

    for raw_line in text.strip().split('\n'):
        category, separator, remainder = raw_line.strip().partition(':')
        if not separator:
            continue  # Ignore headers or tails that carry no "key: value" pair

        category = category.strip()
        # Drop blanks so '' is never treated as a keyword downstream
        values = [word.strip() for word in remainder.split(',') if word.strip()]
        if not category or not values:
            continue

        result_dict[category] = values

    return result_dict

def clean_text(text):
    """
    Normalise a label for comparison.

    Non-string input is converted with str(). Every character other than
    ASCII letters, digits and spaces is removed, then the result is stripped
    of surrounding whitespace and lower-cased.
    """
    raw = text if isinstance(text, str) else str(text)
    alnum_and_spaces = re.sub(r'[^a-zA-Z0-9 ]', '', raw)
    return alnum_and_spaces.strip().lower()

def filter_result_dict(result_dict, valid_categories):
    """
    Keep only the entries of result_dict whose key names a valid category.

    Keys of result_dict and values of valid_categories are both normalised
    with clean_text before being compared. The returned dictionary is keyed
    by the normalised category names.
    """
    # Normalised names of every accepted category
    allowed = set()
    for label in valid_categories.values():
        allowed.add(clean_text(str(label)))

    # Re-key the parsed result by normalised category name
    normalised = {}
    for raw_key, keywords in result_dict.items():
        normalised[clean_text(raw_key)] = keywords

    return {name: keywords for name, keywords in normalised.items() if name in allowed}

# Parse the model's reply, then keep only the recognised issue areas
result = parse_text_to_dict(response['message']['content'])
filtered_result = filter_result_dict(result, sc.issue_area_codes)

# Show each retained category followed by its keyword list
for category, keywords in filtered_result.items():
    print(category)
    print(keywords)
    

In [None]:
# Materialise every dataset record and load them into a DataFrame
records = [record for record in sc.records()]
df = pd.DataFrame(records)

# Print a preview of the first few rows
preview = df.head()
print(preview)

In [None]:
from sentence_transformers import SentenceTransformer
import chromadb
import ollama
import nltk
from nltk.corpus import stopwords
import re

# Download the NLTK stopword corpus.
# NOTE(review): `stopwords` is imported above but is not referenced anywhere in
# the visible notebook — confirm whether this download is still needed.
nltk.download("stopwords")

# Process-wide cache. The sentence encoder and the populated keyword collection
# are expensive to build, so they are created once and reused across calls.
_RAG_CACHE = {"encoder": None, "collection": None, "fingerprint": None}


def _keyword_entries(classification_dict):
    """Flatten the category -> keywords mapping into unique, non-empty (class_name, keyword) pairs."""
    seen = set()
    entries = []
    for class_name, keywords in classification_dict.items():
        for keyword in keywords:
            keyword = str(keyword).strip()
            if keyword and (class_name, keyword) not in seen:
                seen.add((class_name, keyword))
                entries.append((class_name, keyword))
    return entries


def _get_keyword_collection(entries):
    """Return (encoder, collection), re-indexing the keywords only when they have changed."""
    cache = _RAG_CACHE
    if cache["encoder"] is None:
        cache["encoder"] = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    encoder = cache["encoder"]

    fingerprint = tuple(entries)
    if cache["collection"] is None or cache["fingerprint"] != fingerprint:
        # The collection lives on disk under ./chroma_db, so it can still hold
        # keywords written by an earlier run; purge them before indexing.
        chroma_client = chromadb.PersistentClient(path="./chroma_db")
        collection = chroma_client.get_or_create_collection(name="temp_case_classification")
        stale_ids = collection.get()["ids"]
        if stale_ids:
            collection.delete(ids=stale_ids)

        # Ids combine category and keyword so that the same keyword listed under
        # two categories does not collide; all keywords are embedded in one batch.
        keywords = [keyword for _, keyword in entries]
        collection.add(
            ids=[f"{class_name}::{keyword}" for class_name, keyword in entries],
            embeddings=encoder.encode(keywords).tolist(),
            metadatas=[{"class_name": class_name} for class_name, _ in entries],
        )
        cache["collection"] = collection
        cache["fingerprint"] = fingerprint

    return encoder, cache["collection"]


def classify_case_rag(case_text, classification_dict, n_results=3, max_case_chars=1000, llm_model="llama3"):
    """
    Classify a case with Retrieval-Augmented Generation (RAG).

    The keywords of every category are embedded and indexed in a ChromaDB
    collection (built once and reused while the keywords stay the same). The
    case text is embedded with the same encoder, the categories of the nearest
    keywords are retrieved, and the LLM is asked to choose among them.

    Args:
        case_text (str): The text of the case to be classified.
        classification_dict (dict): Maps category names to lists of keywords.
        n_results (int): Number of nearest keywords to retrieve.
        max_case_chars (int): Number of leading characters of the case text
            included in the LLM prompt.
        llm_model (str): Name of the Ollama model to query.

    Returns:
        str: The model's reply (expected to be a category name), or
        "No relevant classification found" when there is nothing to retrieve.
    """
    entries = _keyword_entries(classification_dict)
    if not entries:
        return "No relevant classification found"

    encoder, collection = _get_keyword_collection(entries)

    # Encode the case text and retrieve the closest keywords; never request
    # more neighbours than the collection holds.
    case_embedding = encoder.encode(case_text).tolist()
    results = collection.query(
        query_embeddings=[case_embedding],
        n_results=max(1, min(n_results, len(entries))),
    )

    metadatas = results["metadatas"] or [[]]
    # Several keywords may map to one category; keep each category once, in rank order.
    retrieved_classes = list(dict.fromkeys(meta["class_name"] for meta in metadatas[0]))

    if not retrieved_classes:
        return "No relevant classification found"

    # Prepare the retrieved context for the LLM
    retrieved_text = "\n".join(retrieved_classes)
    # Limit the case text so the prompt stays concise
    case_excerpt = case_text[:max_case_chars]

    prompt = f"""
    You are an AI assistant trained in Supreme Court case classification. Given the following retrieved categories,
    classify the new case into the most appropriate issue area.

    Retrieved Relevant Categories:
    {retrieved_text}

    Case Text:
    {case_excerpt}

    Please reply with the category only and nothing else."""

    # Query the LLM through the local Ollama instance
    response = ollama.chat(model=llm_model, messages=[{"role": "user", "content": prompt}])

    return response['message']['content'].strip()


In [None]:
%%capture
# Process and classify cases
classified_cases = []


def process_cases():
    """
    Classify every row of cases_df, one after another.

    Each case text is passed to classify_case_rag together with
    filtered_result. Cases that receive a non-empty class name are appended
    to the module-level classified_cases list; the others are reported as
    failures.
    """
    for idx, case_text in cases_df["text"].items():
        print(f"Processing case {idx + 1}/{len(cases_df)}...")

        # Classify the case from its raw text
        class_name = classify_case_rag(case_text, filtered_result)

        if not class_name:
            print(f"Failed to classify case {idx + 1}")
            continue

        classified_cases.append({"text": case_text, "class_name": class_name})

    print("Processing completed.")


# Run the case processing
process_cases()



In [None]:
# Collect the classification results into a DataFrame and replace any missing
# class_name with the literal string 'None'
classified_df = pd.DataFrame(classified_cases)
classified_df = classified_df.fillna({'class_name': 'None'})

# Display the results
classified_df.head()

In [None]:
# Write the classified cases to the Kaggle working directory
output_csv = "/kaggle/working/classified_llama3.csv"
classified_df.to_csv(output_csv)

In [None]:
# Print the predicted class of the first classified case (nothing if the frame is empty)
for row_label, first_row in classified_df.head(1).iterrows():
    print(first_row['class_name'])