In [1]:
# def classify_legal_case(case_text, category_embeddings, model, top_n=3, threshold=0.5, strategy="max"):
    
#     chunks = chunk_text(case_text)
#     chunk_embeddings = model.encode(chunks)
    
#     similarities = {}
    
#     for category_name, category_data in category_embeddings.items():
#         cat_embedding = category_data['embedding'].reshape(1, -1)
        
#         scores = cosine_similarity(chunk_embeddings, cat_embedding).flatten()
        
#         if strategy == "max":
#             similarity = float(np.max(scores)) 
#         elif strategy == "average":
#             similarity = float(np.mean(scores)) 
        
#         similarities[category_name] = similarity
    
#     sorted_categories = sorted(similarities.items(), key=lambda x: x[1], reverse=True)
    
#     results = []
#     for category, similarity in sorted_categories[:top_n]:
#         if similarity >= threshold:
#             results.append(f"{category} ({similarity:.3f})")
    
#     return results if results else ["Unclassified"]

In [2]:
# df = pd.read_csv("SupremeCourt_cases710.csv") 

# predictions = []
# for case_text in tqdm(df["Case Content"], desc="Classifying cases"):
#     result = classify_legal_case(case_text, category_embeddings, model, top_n=1, threshold=0.5, strategy="max")
#     predictions.append("; ".join(result)) 

# df["Predicted_Category"] = predictions

In [3]:
import toons
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm
import sys


# --- Configuration ---
# Input/output paths are bare file names, so they resolve against the
# notebook's current working directory.
TAXONOMY_FILE = 'legal_taxonomy_with_embeddings.toon'
CASES_FILE = '9700cases_embedded.toon'
OUTPUT_FILE = '9700cases_classified.toon'
# Key under which each taxonomy category stores its embedding vector.
EMBEDDING_KEY = 'embeddings'


def _load_toon_file(path, label):
    """Read a TOON file from disk and return the parsed data.

    Args:
        path: Path of the TOON file to read (opened as UTF-8 text).
        label: Lower-case dataset name ("taxonomy", "cases") used to build
            the error messages.

    Returns:
        Whatever ``toons.loads`` produces for the file's text.

    Raises:
        SystemExit: With exit status 1 when the file is missing or when
            reading/parsing fails. A non-zero status lets a non-interactive
            run (script, nbconvert, CI) detect the failure; a bare
            ``sys.exit()`` would report success.
    """
    try:
        with open(path, 'r', encoding='utf-8') as f:
            return toons.loads(f.read())
    except FileNotFoundError:
        print(f"FATAL ERROR: {label.capitalize()} file not found at {path}")
    except Exception as e:
        print(f"Error loading {label}: {e}")
    # Only reached when one of the handlers above ran.
    sys.exit(1)


print("Loading data...")


# --- 1. Load Taxonomy ---
taxonomy = _load_toon_file(TAXONOMY_FILE, "taxonomy")


# --- 2. Load Cases ---
cases_data = _load_toon_file(CASES_FILE, "cases")

print(f"Loaded {len(taxonomy)} categories and {len(cases_data)} cases.")


# --- 3. Prepare Taxonomy Embeddings for Fast Comparison ---
# category_names and category_embeddings_list are appended to together, so
# category_names[i] is the label of row i in the matrix built below.
category_names = []
category_embeddings_list = []

for cat_name, cat_data in taxonomy.items():
    # Only categories that actually carry an embedding can be compared against;
    # the rest are reported and left out of the matrix.
    if EMBEDDING_KEY in cat_data:
        category_names.append(cat_name)
        category_embeddings_list.append(np.array(cat_data[EMBEDDING_KEY]))
    else:
        print(f"Warning: Category '{cat_name}' has no key named '{EMBEDDING_KEY}'")

# Fail with an actionable message instead of numpy's generic
# "need at least one array to concatenate" when nothing usable was found
# (e.g. EMBEDDING_KEY does not match the taxonomy file's schema).
if not category_embeddings_list:
    raise ValueError(
        f"No taxonomy category has an '{EMBEDDING_KEY}' entry; "
        "cannot build the category embedding matrix."
    )

# Create one big matrix from all category embeddings
# Shape will be (num_categories, embedding_dimension)
category_embeddings_stack = np.vstack(category_embeddings_list)
print(f"Created category embedding matrix with shape: {category_embeddings_stack.shape}")


# --- 4. Classify Cases ---
# Each case is labelled in place (new 'case_category' key) with the taxonomy
# category that is most similar to any one of the case's chunk embeddings.
print("Classifying cases...")
for case in tqdm(cases_data, desc="Classifying"):

    # Every non-empty list stored under a key ending in "_embedding" counts
    # as one chunk vector of this case.
    chunk_vectors = [
        np.array(vector)
        for field, vector in case.items()
        if field.endswith("_embedding") and isinstance(vector, list) and len(vector) > 0
    ]

    if chunk_vectors:
        # One row per chunk, one column per category.
        chunk_matrix = np.vstack(chunk_vectors)
        similarity = cosine_similarity(chunk_matrix, category_embeddings_stack)

        # "max" strategy: score each category by its best-matching chunk,
        # then keep the category with the highest such score.
        winner = similarity.max(axis=0).argmax()
        case['case_category'] = category_names[winner]
    else:
        # Nothing to compare against the taxonomy for this case.
        case['case_category'] = "Unclassified (No Embeddings)"

print("Classification complete.")

# --- 5. Save Results ---
# The annotated cases are serialised before the output file is opened, so a
# serialisation failure never touches the file. Any error is reported on
# stdout rather than raised.
try:
    print(f"Saving classified data to {OUTPUT_FILE}...")
    serialized_cases = toons.dumps(cases_data)
    with open(OUTPUT_FILE, 'w', encoding='utf-8') as out_handle:
        out_handle.write(serialized_cases)
except Exception as save_error:
    print(f"Error saving file: {save_error}")
else:
    print(f"Success! Saved updated data to {OUTPUT_FILE}")

Loading data...
Loaded 66 categories and 9760 cases.
Created category embedding matrix with shape: (66, 384)
Classifying cases...


Classifying: 100%|██████████| 9760/9760 [00:05<00:00, 1916.07it/s]


Classification complete.
Saving classified data to 9700cases_classified.toon...
Success! Saved updated data to 9700cases_classified.toon
