In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from transformers import AutoTokenizer
from semantic_similarity import calculate_semantic_similarity
from sentence_tokenizer import process_directory
from summarization import summarize_document_with_pipeline
from information_clustering import create_document_embeddings, perform_clustering_and_visualize

# Semantic Textual Similarity

In [None]:
# Locations of the source documents and of the case metadata table
amici_dir, opinions_dir = '../Data/Amici', '../Data/Opinions'
metadata_file = '../Data/metadata.csv'

# JSONL paths handed to process_directory and then to calculate_semantic_similarity
amici_jsonl, opinions_jsonl = (
    '../Data/amici_sentences.jsonl',
    '../Data/opinions_sentences.jsonl',
)

In [None]:
# Run process_directory over each source folder, passing the JSONL path
# configured for it. Amici briefs are processed first, then opinions.
# NOTE(review): process_directory presumably writes sentence-level records to
# the given JSONL path and returns per-document metadata — confirm in
# sentence_tokenizer.
amici_metadata = process_directory(amici_dir, amici_jsonl)
opinions_metadata = process_directory(opinions_dir, opinions_jsonl)

In [None]:
# Compare the opinion sentences with the amicus sentences and keep the
# returned per-case results for the cells that follow.
# NOTE(review): the similarity matrices are presumably saved under output_dir
# — confirm in semantic_similarity.
results = calculate_semantic_similarity(
    opinions_jsonl,
    amici_jsonl,
    metadata_file,
    output_dir="../Data/similarity_matrices",
    threshold=0.9,
    model_name="sentence-transformers/all-mpnet-base-v2",
)

In [None]:
# Case whose matched sentence pairs are listed; it must equal a key of `results`.
case_of_interest = "Burwell v. Hobby Lobby"

if case_of_interest in results:
    print(f"Highly similar pairs for case '{case_of_interest}':")
    # Each pair holds the opinion sentence, the amicus sentence and their score
    for pair in results[case_of_interest]["highly_similar_pairs"]:
        print("Opinion Sentence:", pair["opinion_sentence"])
        print("Amicus Sentence:", pair["amicus_sentence"])
        print("Similarity Score:", pair["similarity_score"])
else:
    # Say so explicitly: otherwise a misspelled or absent case name would
    # produce an empty cell output that looks like "no matches".
    print(f"No similarity results found for case '{case_of_interest}'.")

In [None]:
# Case names and their "percent_copied" values, both in the iteration order
# of `results` so bar i belongs to case i.
cases = list(results)
percent_copied_values = [case_data["percent_copied"] for case_data in results.values()]

# Bar chart of percent copied per case, with a dashed reference line at 2.7%
fig, ax = plt.subplots(figsize=(10, 6))
ax.bar(cases, percent_copied_values, color='mediumseagreen')
ax.axhline(y=2.7, color='black', linestyle='--', label='2.7% (Collins et al., 2015)')
ax.set_xlabel('Case')
ax.set_ylabel('Percent Copied')
ax.set_title('Percent of Opinion Text Similar to Amici Briefs by Case')
# Tilt the case names so long labels do not run into each other
plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
ax.legend()
fig.tight_layout()
plt.show()

In [None]:
# Running totals: plain sum of percentages, word-weighted sum, and word count
total_percent_copied = 0
total_weighted_percent_copied = 0
total_words_in_all_opinions = 0

# Every entry of `results` must provide "percent_copied" and "total_words"
for entry in results.values():
    pct, n_words = entry["percent_copied"], entry["total_words"]
    total_percent_copied += pct
    total_weighted_percent_copied += pct * n_words
    total_words_in_all_opinions += n_words

num_cases = len(results)

# Unweighted mean over cases; falls back to 0 when there are no cases
if num_cases > 0:
    simple_average = total_percent_copied / num_cases
else:
    simple_average = 0

# Mean weighted by each opinion's word count; 0 when no words were counted
if total_words_in_all_opinions > 0:
    weighted_average = total_weighted_percent_copied / total_words_in_all_opinions
else:
    weighted_average = 0

# Report both figures
print(f"Simple Average Percentage of Text Copied: {simple_average:.2f}%")
print(f"Opinion-Length Weighted Average Percentage of Text Copied: {weighted_average:.2f}%")


# Classification

# Information Clustering

In [None]:
# Initialize the tokenizer
# Loads the tokenizer of the BigBird checkpoint named in `model_name` and keeps
# it in the module-level `tokenizer`, which count_tokens reads as a global.
model_name = 'google/bigbird-roberta-large'
tokenizer = AutoTokenizer.from_pretrained(model_name)

def count_tokens(directory):
    """Count tokenizer tokens for every ``.txt`` file directly inside ``directory``.

    Files are visited in sorted filename order so the returned list is
    reproducible; ``os.listdir`` on its own yields entries in arbitrary order.
    The module-level ``tokenizer`` is used to tokenize each file.

    Args:
        directory: Path of the folder to scan (sub-folders are not searched).

    Returns:
        list[int]: One entry per ``.txt`` file, equal to the length of the
        tokenizer's ``input_ids`` for that file's full text.

    Raises:
        FileNotFoundError: If ``directory`` does not exist (from ``os.listdir``).
    """
    token_counts = []
    for filename in sorted(os.listdir(directory)):
        # Only plain-text documents are counted
        if not filename.endswith(".txt"):
            continue

        file_path = os.path.join(directory, filename)
        with open(file_path, "r", encoding="utf-8") as file:
            text = file.read()

        # truncation=False so the whole document is tokenized, not a clipped prefix
        encoded = tokenizer(text, truncation=False)
        token_counts.append(len(encoded['input_ids']))

    return token_counts

# Count tokens in each amicus brief.
# The path is relative to the notebook's directory, like every other
# '../Data/...' location in this notebook; "Data/Amici" would not resolve
# from the same working directory.
directory = "../Data/Amici"
token_counts = count_tokens(directory)

# Plot histogram of token counts, with a dashed marker at the 4096-token limit
plt.figure(figsize=(10, 6))
plt.hist(token_counts, bins=30, color='mediumseagreen', edgecolor='black')
plt.axvline(x=4096, color='black', linestyle='--', label='4096 Token Limit')
plt.xlabel("Number of Tokens")
plt.ylabel("Number of Documents")
plt.title("Histogram of Amicus Curiae Token Counts")
plt.legend()
plt.show()

In [None]:

# Paths used to create and later reuse the document embeddings
DATA_DIRS = ["../Data/Amici"]
METADATA_FILE = "../Data/metadata.csv"
OUTPUT_EMBEDDINGS = "../Data/embeddings.pt"
OUTPUT_LABELS = "../Data/labels.csv"

# Build the embeddings only when the cached files are missing: a fresh
# checkout can then run this cell end to end, while reruns skip the
# recomputation.
# NOTE(review): assumes create_document_embeddings writes both
# OUTPUT_EMBEDDINGS and OUTPUT_LABELS — confirm in information_clustering.
if not (os.path.exists(OUTPUT_EMBEDDINGS) and os.path.exists(OUTPUT_LABELS)):
    create_document_embeddings(DATA_DIRS, METADATA_FILE, OUTPUT_EMBEDDINGS, OUTPUT_LABELS)

# Cluster the stored embeddings into N_CLUSTERS groups and visualize them;
# OUTPUT_PLOT is the image path passed to the plotting helper.
N_CLUSTERS = 6
OUTPUT_PLOT = "../Data/clustering_visualization.png"
perform_clustering_and_visualize(OUTPUT_EMBEDDINGS, OUTPUT_LABELS, N_CLUSTERS, OUTPUT_PLOT)


# Summarization

In [None]:
# Document summarized by both model families
file_path = "../Data/nfib v sebelius syllabus.txt"

# FLAN-T5 checkpoints, run with max_input_length=512.
# The results get their own name so the LED run below does not overwrite them.
flan_t5_model_names = [
    "google/flan-t5-small",
    "google/flan-t5-base",
    "google/flan-t5-large",
    "google/flan-t5-xl",
]
flan_t5_summaries = summarize_document_with_pipeline(
    file_path, flan_t5_model_names, max_input_length=512
)

# LED checkpoints, run with max_input_length=2900
led_model_names = [
    "allenai/led-base-16384",
    "allenai/led-base-16384-ms2",
    "allenai/led-base-16384-cochrane",
    "allenai/led-large-16384",
    "allenai/led-large-16384-arxiv",
]
led_summaries = summarize_document_with_pipeline(
    file_path, led_model_names, max_input_length=2900
)

# Keep the names this cell has always defined; as before, they end up
# referring to the LED run.
model_names = led_model_names
summaries = led_summaries