In [None]:
from datasets import load_dataset, Value, ClassLabel
import pandas as pd
import numpy as np
import random
from tqdm import tqdm

In [None]:
# # Only need to run this once to download the dataset

# # Load dataset and take a random sample of 100,000 rows
# sample_size = 100000

# ds = load_dataset('free-law/Caselaw_Access_Project', streaming=True, split='train')

# # Estimated rows (from HuggingFace)
# total_size = 4284276

# # Calculate sampling probability
# sample_size = 100_000
# sampling_prob = sample_size / total_size

# # Take reservoir sample
# sampled_data = []
# for i, example in tqdm(enumerate(ds)):
#     if random.random() < sampling_prob:
#         sampled_data.append(example)
    
#     # Print progress periodically
#     if i % 100000 == 0:
#         print(f"Processed {i} examples, currently have {len(sampled_data)} samples")

In [None]:
# Convert to pandas df
df = pd.DataFrame(sampled_data)

In [None]:
# Convert appropriate columns to numeric
for col in df.columns:
    try:
        df[col] = pd.to_numeric(df[col])
    except:
        continue

In [None]:
df.shape

Save data sample for future use without having to load again.

In [None]:
# Pick up from here
# df.to_parquet('../data/caselaw_sample.parquet')
df = pd.read_parquet('../data/caselaw_sample.parquet')

### Exploratory analysis of the data.

#### Word Counts

In [None]:
# Interested in word count statistics
df['word_count'].describe()

In [None]:
# Mean is around 1800 words, median is around 1100 words. Choosing cutoff of 1200.
# New cutoffs between 5 pages and 10 pages (roughly 500 words per page)
floor_cutoff = 500
ceiling_cutoff = 5000
df_filtered = df.query(f'word_count < {ceiling_cutoff} & word_count > {floor_cutoff}')
df_filtered.shape

In [None]:
df_filtered.hist(column='word_count', bins=20, )

#### Court Types

In [None]:
# Look into the different courts that are present, statistics
court_counts = df_filtered['court'].value_counts()
court_counts[court_counts>500]

Types of courts seem to include:
- Supreme Court of United States
- United States Court of Appleas (for the ...th Circuit)
- United States Customs Court
- Supreme Court of [State]
- Appellate Court / Court of Appeals / Court of Errors and Appeals
- Superior Court
- New York Supreme Court, General Term; New York Supreme Court, New York Supreme Court, Appellate Division
- Court of Claims
- United States Board of Tax Appeals
- District Court
- Other





In [None]:
# Omitted New York because of the ambiguity with the city
state_names = [
    "Alaska", "Alabama", "Arkansas", "American Samoa", "Arizona", "California", "Colorado", "Connecticut", 
    "District of Columbia", "Washington D.C.", "Delaware", "Florida", "Georgia", "Guam", "Hawaii", "Iowa", 
    "Idaho", "Illinois", "Indiana", "Kansas", "Kentucky", "Louisiana", "Massachusetts", "Maryland", "Maine", 
    "Michigan", "Minnesota", "Missouri", "Mississippi", "Montana", "North Carolina", "North Dakota", 
    "Nebraska", "New Hampshire", "New Jersey", "New Mexico", "Nevada", "Ohio", "Oklahoma", 
    "Oregon", "Pennsylvania", "Puerto Rico", "Rhode Island", "South Carolina", "South Dakota", "Tennessee", 
    "Texas", "Utah", "Virginia", "Virgin Islands", "Vermont", "Washington", "Wisconsin", "West Virginia", "Wyoming"
]

def get_court_type(court):
    court = court.lower()
    if 'united states' in court:
        if 'supreme' in court:
            return 'Supreme Court of the United States'
        if any(['appeal' in court, 'appellate' in court]):
            return 'Federal Court of Appeals'
        else:
            return 'Federal Court'
    if any(['appeal' in court, 'appellate' in court, 'commonwealth court of pennsylvania' in court]):
        if any([state.lower() in court for state in state_names]):
            return 'State Court of Appeals'
        elif any(['federal' in court, 'united states' in court, 'u.s.' in court]):
            return 'Federal Court of Appeals'
        else:
            return 'Court of Appeals'
    if 'supreme' in court:
        return 'State Supreme Court'
    if 'district court' in court:
        return 'District Court'
    if 'superior' in court:
        return 'Superior Court'
    if 'court of claims' in court:
        return 'Court of Claims'
    if 'tax' in court:
        return 'Tax Court'
    if 'customs' in court:
        return 'Customs Court'
    if 'common pleas' in court:
        return 'Common Pleas Court'
    if 'circuit court' in court:
        return 'Circuit Court'
    if 'bankruptcy' in court:
        return 'Bankruptcy Court'
    if 'chancery' in court:
        return 'Chancery Court'
    if 'city court' in court:
        return 'City Court'
    if 'county court' in court:
        return 'County Court'
    return 'Other'

df_filtered.loc[:, 'court_type'] = df_filtered.loc[:, 'court'].apply(get_court_type)

In [None]:
df_courttype_counts = df_filtered.loc[:, 'court_type'].value_counts()
print(df_courttype_counts)
df_courttype_counts.plot.hist(title='Court Type Counts', edgecolor='black')

# Look into state vs federal appeals / district courts

#### Jurisdictions

In [None]:
# Look into the different jurisdictions that are present, statistics
jurisdiction_counts = df_filtered.loc[:, 'jurisdiction'].value_counts().sort_values(ascending=False)
jurisdiction_counts

In [None]:
jurisdiction_counts.plot.hist(title='Jurisdiction Counts', bins=15, edgecolor='black')

#### Balanced dataset creation

In [None]:
df_filtered.loc[:, 'court_type_jurisdiction'] = df_filtered.apply(lambda x: f"{x['court_type']} - {x['jurisdiction']}", axis=1)
df_courttype_jurisdiction_counts = df_filtered.loc[:, 'court_type_jurisdiction'].value_counts()
df_courttype_jurisdiction_counts.plot.hist(title='Court Type - Jurisdiction Counts', bins=15, edgecolor='black')

In [None]:
# Zoomed in histogram
df_courttype_jurisdiction_counts.loc[df_courttype_jurisdiction_counts < 2500].plot.hist(title='Court Type - Jurisdiction Counts (Zoomed In)', bins=20, edgecolor='black')

In [None]:
df_courttype_jurisdiction_counts.head()

In [None]:
# Exploring cutoffs
cutoffs = [10, 50, 100, 250, 500]
# Scaler used to estimate number of raw counts in full dataset
total_size = 4_284_276
scaler = total_size / df_filtered.shape[0]


for cutoff in cutoffs:
    print(f"Number of combos over {cutoff}: {df_courttype_jurisdiction_counts.loc[df_courttype_jurisdiction_counts > cutoff].shape[0]}")
    print(f"Number of combos under {cutoff}: {df_courttype_jurisdiction_counts.loc[df_courttype_jurisdiction_counts <= cutoff].shape[0]}")
    print(f"Estimated raw count above  {cutoff}: {int(round(df_courttype_jurisdiction_counts.loc[df_courttype_jurisdiction_counts > cutoff].sum() * scaler, -3))}")
    print(f"Estimated raw count below {cutoff}: {int(round(df_courttype_jurisdiction_counts.loc[df_courttype_jurisdiction_counts <= cutoff].sum() * scaler, -3))}")
    print(f"Estimated raw percent below {cutoff}: {round(df_courttype_jurisdiction_counts.loc[df_courttype_jurisdiction_counts <= cutoff].sum() * scaler / total_size * 100, 2)}")
    print('---------------------------------------------')

Planning on using a cutoff of 50 as the "other" category in my sampler.

In [None]:
# Save the list of court type - jurisdiction combos with more than 10 counts
main = list(df_courttype_jurisdiction_counts[df_courttype_jurisdiction_counts > 10].index)
with open('../data/main_courttype_jurisdiction_combos.txt', 'w') as f:
    for item in main:
        f.write(f"{item}\n")

In [None]:
courttype_jurisdictions = main + ['Other']
len(courttype_jurisdictions)

# Make balanced, stratified sample of full data

In [None]:
ds = load_dataset('free-law/Caselaw_Access_Project', split='train')

In [None]:
ds = ds.remove_columns(['first_page', 'last_page', 'volume', 'last_updated', 'provenance', 'judges', 'parties', 'head_matter', 'char_count'])

In [None]:
ds = ds.filter(lambda x: x['word_count'] != '')

In [None]:
features = ds.features
new_features = features.copy()
new_features['word_count'] = Value("int64")
ds = ds.cast(new_features)

In [None]:
ds = ds.filter(lambda x: x['word_count'] > floor_cutoff and x['word_count'] < ceiling_cutoff)

In [178]:
ds.column_names

['id',
 'name',
 'name_abbreviation',
 'decision_date',
 'docket_number',
 'citations',
 'reporter',
 'court',
 'jurisdiction',
 'word_count',
 'text']

In [None]:
with open('../data/main_courttype_jurisdiction_combos.txt', 'r') as f:
    main = f.read().splitlines()

def get_courttype_jurisdiction(example):
    court = example['court']
    jurisdiction = example['jurisdiction']
    court_type = get_court_type(court)
    combo = f"{court_type} - {jurisdiction}"
    if combo not in main:
        return 'Other'
    return combo

ds = ds.map(lambda x: {'court_type_jurisdiction': get_courttype_jurisdiction(x)}, num_proc=4)

In [None]:
court_type_jurisdictions_unique = list(set(ds['court_type_jurisdiction']))
len(court_type_jurisdictions_unique)

142

In [None]:
new_features = ds.features.copy()
new_features['court_type_jurisdiction'] = ClassLabel(names=court_type_jurisdictions_unique)
ds = ds.cast(new_features, num_proc=4)

In [None]:
# Shuffle the dataset, but write to disk using flatten_indices to speed up
ds = ds.shuffle(seed=42)

In [None]:
ds = ds.flatten_indices(num_proc=4)

In [None]:
ds.shape

In [None]:
from collections import defaultdict

sample_target = 500_000 # Target size of sampled dataset
class_target = sample_target // len(court_type_jurisdictions_unique) # Target size of each class

samples_dict = defaultdict(int)
sampled_indices = []

def check_sample(cls, id):
    if samples_dict[cls] < class_target:
        samples_dict[cls] += 1
        sampled_indices.append(id)

ds.map(lambda x: check_sample(x['court_type_jurisdiction'], x['id']), num_proc=4

IterableDataset({
    features: Unknown,
    num_shards: 58
})

In [None]:
balanced_ds = ds.filter(lambda x: x['id'] in sampled_indices, num_proc=4)

In [None]:
balanced_ds.shape
balanced_ds.to_parquet('../data/caselaw_balanced_1.parquet')

AttributeError: 'IterableDataset' object has no attribute 'to_parquet'

## Load and analyze new set

In [134]:
# Analyze by court types
balanced_df = pd.read_parquet('../data/caselaw_balanced_1.parquet')
balanced_df['court_type'] = balanced_df.court.apply(get_court_type)
balanced_df.court_type.value_counts()

court_type
State Supreme Court                   170731
State Court of Appeals                133008
Superior Court                         26587
Other                                  14334
Circuit Court                          11437
Common Pleas Court                     11316
Chancery Court                          8819
Court of Appeals                        8224
Court of Claims                         7703
Federal Court of Appeals                7063
County Court                            6912
Federal Court                           4021
City Court                              3569
Supreme Court of the United States      3521
District Court                          3375
Tax Court                               2734
Name: count, dtype: int64

In [135]:
# Analyze by jurisdictions
balanced_df.jurisdiction.value_counts()

jurisdiction
New York                    30676
Pennsylvania                20086
New Jersey                  17388
Ohio                        16721
Connecticut                 12049
United States               11087
Massachusetts               10772
Virginia                    10768
Illinois                    10617
Missouri                    10563
Florida                     10445
Delaware                    10236
South Carolina               8735
Oregon                       8279
District of Columbia         8032
Indiana                      7694
Oklahoma                     7529
Tennessee                    7209
California                   7200
North Carolina               7164
Georgia                      7143
Michigan                     7134
Kentucky                     7118
Louisiana                    7118
Arkansas                     7112
Mississippi                  7084
Arizona                      7060
Colorado                     7051
Texas                        7050
W

In [None]:
from random import choices
from collections import defaultdict

def down_sampler(df, target_size, column, column_value, group_by_column):
    """ Down samples the dataframe to the target size for a specific column value. 
        Prob of removal is proportional to size of other column values.
        
        Args:
            df: DataFrame to downsample
            target_size: Target size of downsampled dataframe
            column: Column to filter on
            column_value: Value to filter on
            group_by_column: Column to group by for downsampling
    """
    df1 = df.loc[df[column] == column_value, :].sample(frac=1, random_state=42)
    unique_values = list(df1[group_by_column].unique())
    value_counter = [df1.loc[df1[group_by_column] == val].shape[0] for val in unique_values]
    num_removes = df1.shape[0] - target_size
    
    removal_counts = defaultdict(int)
    while num_removes > 0:
        idx = choices(range(len(value_counter)), weights=value_counter)[0]
        value_counter[idx] -= 1
        removal_counts[unique_values[idx]] += 1
        num_removes -= 1
        
    # Remove samples by index
    for val, count in removal_counts.items():
        drop_indices = df1.loc[df1[group_by_column] == val].index[:count]
        df1 = df1.drop(drop_indices)
        
    # Recombine downsampled dataframe with original
    df_remaining = df.loc[df[column] != column_value, :]
    df_downsampled = pd.concat([df_remaining, df1])
    
    return df_downsampled

def down_sample_looper(df, target_size, column, column_values, group_by_column):
    """ Loops through values of a column and downsamples each one to the target size. """
    df_downsampled = df.copy()
    for value in column_values:
        if df_downsampled.loc[df_downsampled[column] == value].shape[0] > target_size:
            df_downsampled = down_sampler(df_downsampled, target_size, column, value, group_by_column)
    return df_downsampled


In [148]:
max_jurisdiction_size = 10_000
jurisdiction_list = list(jur for jur in balanced_df.jurisdiction.unique() if balanced_df.loc[balanced_df.jurisdiction == jur].shape[0] > max_jurisdiction_size)
downsampled_jurisdiction_df = down_sample_looper(balanced_df, max_jurisdiction_size, 'jurisdiction', jurisdiction_list, 'court_type')
downsampled_jurisdiction_df.jurisdiction.value_counts()

jurisdiction
Delaware                    10000
Pennsylvania                10000
United States               10000
Missouri                    10000
New York                    10000
Virginia                    10000
Florida                     10000
New Jersey                  10000
Massachusetts               10000
Ohio                        10000
Illinois                    10000
Connecticut                 10000
South Carolina               8735
Oregon                       8279
District of Columbia         8032
Indiana                      7694
Oklahoma                     7529
Tennessee                    7209
California                   7200
North Carolina               7164
Georgia                      7143
Michigan                     7134
Kentucky                     7118
Louisiana                    7118
Arkansas                     7112
Mississippi                  7084
Arizona                      7060
Colorado                     7051
Texas                        7050
W

Seems a lot more reasonable in terms of jurisdiction balance. Let's look at court type.

In [152]:
downsampled_jurisdiction_df.shape[0]

371946

In [151]:
downsampled_jurisdiction_df.court_type.value_counts()

court_type
State Supreme Court                   161933
State Court of Appeals                126681
Superior Court                         19709
Circuit Court                           9435
Other                                   8835
Federal Court of Appeals                6711
Court of Claims                         5914
Common Pleas Court                      5662
Chancery Court                          5591
Court of Appeals                        5587
Federal Court                           3638
Supreme Court of the United States      3197
County Court                            2972
District Court                          2538
Tax Court                               2336
City Court                              1207
Name: count, dtype: int64

Still very unbalanced in terms of court type. Let's downsample again.

In [167]:
downsampled_court_type_df.jurisdiction.value_counts()

jurisdiction
United States               9992
New York                    9191
Pennsylvania                7721
Delaware                    7613
New Jersey                  7374
Ohio                        7252
Connecticut                 6192
Massachusetts               5768
Virginia                    5720
Illinois                    5696
Missouri                    5666
Florida                     5658
District of Columbia        5244
South Carolina              4181
New Hampshire               3900
Oregon                      3767
Indiana                     3172
West Virginia               3163
Oklahoma                    2976
Alaska                      2792
California                  2668
Tennessee                   2665
Georgia                     2611
Kentucky                    2587
Michigan                    2575
Arkansas                    2573
Minnesota                   2524
Texas                       2521
Colorado                    2514
North Carolina              25

In [160]:
# Impose limit of 70,000 samples per court type
max_court_type_size = 50_000
court_type_list = list(court for court in downsampled_jurisdiction_df.court_type.unique() if downsampled_jurisdiction_df.loc[downsampled_jurisdiction_df.court_type == court].shape[0] > max_court_type_size)
downsampled_court_type_df = down_sample_looper(downsampled_jurisdiction_df, max_court_type_size, 'court_type', court_type_list, 'jurisdiction')
downsampled_court_type_df.court_type.value_counts()

court_type
State Court of Appeals                50000
State Supreme Court                   50000
Superior Court                        19709
Circuit Court                          9435
Other                                  8835
Federal Court of Appeals               6711
Court of Claims                        5914
Common Pleas Court                     5662
Chancery Court                         5591
Court of Appeals                       5587
Federal Court                          3638
Supreme Court of the United States     3197
County Court                           2972
District Court                         2538
Tax Court                              2336
City Court                             1207
Name: count, dtype: int64

In [166]:
downsampled_court_type_df.shape[0]

183332

Calling it there.

In [168]:
# Save the balanced dataset
downsampled_court_type_df.to_parquet('../data/caselaw_downsampled.parquet')