In [1]:
!pip install scikit-learn

Defaulting to user installation because normal site-packages is not writeable


In [6]:
from bs4 import BeautifulSoup
import os
from glob import glob
import dask
import dask.array as da
import dask.bag as db
import dask.dataframe as dd
from dask import delayed
import pandas as pd
from distributed import Client
from dask_jobqueue import SLURMCluster
from IPython.display import display
import matplotlib.pyplot as plt
import time
import numpy as np
import pyarrow
from dask.diagnostics import ProgressBar
import time
import csv
from tqdm import tqdm
from tqdm.notebook import tqdm
tqdm.pandas()
import io
import sys
from dask.diagnostics import ProgressBar
from sklearn.model_selection import GroupShuffleSplit
import nltk
from nltk import word_tokenize
# Notebook-wide display settings: do not truncate the number of columns shown
# or the width of individual cell contents when rendering DataFrames.
pd.set_option('display.max_columns', None)
pd.set_option('display.max_colwidth', None)
# Raise the csv module's maximum field size to sys.maxsize.
csv.field_size_limit(sys.maxsize)


# Execution-mode switch:
#   LOCAL = True  -> single-machine dask client (use while developing)
#   LOCAL = False -> dask client backed by a SLURM cluster
LOCAL = False


if not LOCAL:
    # Cluster mode: workers are launched through SLURM.

    # Interpreter path is derived from the dask version running in this kernel
    # so that the workers use a matching Python installation.
    worker_python = '/scratch/work/public/dask/{}/bin/python'.format(dask.__version__)

    # Extra job flags; output logs are placed under /scratch/<SLURM_JOB_USER>/.
    # Raises KeyError when SLURM_JOB_USER is not set in the environment.
    sbatch_flags = '--export=NONE --output=/scratch/{}/slurm-%j.out'.format(os.environ['SLURM_JOB_USER'])

    cluster = SLURMCluster(
        # Memory and core limits should be sufficient here
        memory='64GB',
        cores=8,
        python=worker_python,
        job_extra=[sbatch_flags],
    )

    # Override the submit command before any scaling is requested
    cluster.submit_command = 'slurm'
    cluster.scale(200)

    display(cluster)
    client = Client(cluster)
else:
    # Local mode: single-machine dask client with default settings
    client = Client()

display(client)

Perhaps you already have a cluster running?
Hosting the HTTP server on port 33393 instead


Tab(children=(HTML(value='<div class="jp-RenderedHTMLCommon jp-RenderedHTML jp-mod-trusted jp-OutputArea-outpu…

0,1
Connection method: Cluster object,Cluster type: dask_jobqueue.SLURMCluster
Dashboard: http://10.32.35.45:33393/status,

0,1
Dashboard: http://10.32.35.45:33393/status,Workers: 0
Total threads: 0,Total memory: 0 B

0,1
Comm: tcp://10.32.35.45:46583,Workers: 0
Dashboard: http://10.32.35.45:33393/status,Total threads: 0
Started: Just now,Total memory: 0 B


In [7]:
# Load the decision (opinion) text table from CSV into a pandas DataFrame.
# NOTE: the path is absolute and user-specific; change it to wherever the file lives.
text_df = pd.read_csv('/vast/mcn8851/text-opinions-data-sc.csv') #Update File Path

In [9]:
# Load the cluster table (which carries the decision direction) from CSV into a pandas DataFrame.
# NOTE: the path is absolute and user-specific; change it to wherever the file lives.
cluster_df = pd.read_csv('/vast/mcn8851/opinions-cluster-data-sc.csv') #Update File Path

In [10]:
# Restore the default column width for DataFrame display
pd.reset_option('display.max_colwidth')

# Remove the first column of the text table (the unnamed column).
# NOTE: this is positional, so re-running the cell drops a further column.
text_df = text_df.drop(columns=text_df.columns[0])

# The cluster table's "id" is the cluster id; rename it so both tables share the merge key
cluster_df = cluster_df.rename({"id": "cluster_id"}, axis="columns")

# Left join: keep every opinion row and attach its cluster's columns
data_df = pd.merge(text_df, cluster_df, on='cluster_id', how='left')

In [11]:
# Choose subset of columns
filtered_df = data_df[['id', 'cluster_id','type','decision_text', 'date_filed', 'scdb_decision_direction']]

# Keep opinions filed on or after 1930-01-01.
# NOTE: this is a string comparison; it relies on date_filed being in YYYY-MM-DD form.
filtered_df = filtered_df[filtered_df['date_filed']>='1930-01-01'].reset_index(drop=True)


# Filter for specific opinion types
types = ["010combined", "015unanimous", "020lead", "025plurality", "030concurrence", "035concurrenceinpart", "040dissent"]
# .copy() gives an independent frame, so the label edits below never act on
# (or warn about) a slice of filtered_df
data = filtered_df[filtered_df['type'].isin(types)].copy()

# Flip dissent labels (1.0 <-> 2.0).
# Both masks must be computed BEFORE either assignment. Evaluating the second
# mask after the first assignment would also match the rows that were just
# changed from 1.0 to 2.0 and set them back to 1.0, leaving every dissent at 1.0.
is_dissent = data['type'] == '040dissent'
dissent_was_1 = is_dissent & (data['scdb_decision_direction'] == 1.0)
dissent_was_2 = is_dissent & (data['scdb_decision_direction'] == 2.0)
data.loc[dissent_was_1, 'scdb_decision_direction'] = 2.0
data.loc[dissent_was_2, 'scdb_decision_direction'] = 1.0

# Drop rows with null values for decision text or decision direction
data = data.dropna(subset=['decision_text', 'scdb_decision_direction'])

# Drop decision directions that could not be discerned (coded 3.0)
data = data[data['scdb_decision_direction'] != 3.0].copy()
data['scdb_decision_direction'] = data['scdb_decision_direction'].astype('category')
#data['scdb_decision_direction'] = data['scdb_decision_direction'].map({1.0: 0, 2.0: 1})
display(data)

Unnamed: 0,id,cluster_id,type,decision_text,date_filed,scdb_decision_direction
0,9420371,104708,020lead,\nMr. Justice Frankfurter\ndelivered the opini...,1949-06-27,1.0
1,9419099,103340,020lead,\nMr. Justice Roberts\ndelivered the opinion o...,1940-05-20,2.0
2,9419100,103340,040dissent,"\nMr. Justice Black, Mr. Justice Douglas, and ...",1940-05-20,1.0
3,103341,103341,010combined,\n\n \n *18\n \n MR. Justice Black\n \n\n...,1940-05-20,2.0
4,103342,103342,010combined,\n\n Mr. Justice Reed\n \n\n delivered the o...,1940-04-22,2.0
...,...,...,...,...,...,...
26877,105923,105923,010combined,\n360 U.S. 395 (1959)\nPITTSBURGH PLATE GLASS ...,1959-10-12,1.0
26878,1855418,1855418,010combined,\n\n \n *217\n \n Per Curiam.\n \n\n I\n...,2000-01-12,2.0
26879,2680438,2680438,010combined,\n\n Petitioner Fifth Third Bancorp maintains...,2014-06-25,1.0
26880,112890,112890,010combined,\n509 U.S. 86 (1993)\nHARPER et al.\nv.\nVIRGI...,1993-06-18,2.0


In [12]:
# Number of distinct values in each column of the filtered data
data.nunique()

id                         26425
cluster_id                 11204
type                           6
decision_text              26401
date_filed                  2924
scdb_decision_direction        2
dtype: int64

In [13]:
# First cut: hold out a test split (test_size=0.2), grouping on cluster_id so
# that rows sharing a cluster always land on the same side of the split
splitter = GroupShuffleSplit(test_size=0.2, n_splits=1, random_state=7)
train_temp_inds, test_inds = next(splitter.split(data, groups=data['cluster_id']))

# Second cut: divide the remaining rows into train and validation
# (test_size=0.25), again grouped on cluster_id
train_temp = data.iloc[train_temp_inds]
splitter = GroupShuffleSplit(test_size=0.25, n_splits=1, random_state=7)
train_inds, val_inds = next(splitter.split(train_temp, groups=train_temp['cluster_id']))

# train_inds / val_inds are positions within train_temp, not within data
train = train_temp.iloc[train_inds]
val = train_temp.iloc[val_inds]
test = data.iloc[test_inds]

In [14]:
# Leakage check: no cluster_id may appear in more than one of train / val / test
cluster_ids_train = train['cluster_id'].tolist()
cluster_ids_val = val['cluster_id'].tolist()
cluster_ids_test = test['cluster_id'].tolist()

# Pairwise overlaps between the splits' cluster ids
common_train_val = set(cluster_ids_train) & set(cluster_ids_val)
common_train_test = set(cluster_ids_train) & set(cluster_ids_test)
common_val_test = set(cluster_ids_val) & set(cluster_ids_test)

# One line per pair, in the order train/val, train/test, val/test
for shared_ids in (common_train_val, common_train_test, common_val_test):
    print("Elements in common" if shared_ids else "No elements in common")


No elements in common
No elements in common
No elements in common


In [15]:
# Persist each split to its own CSV file (the DataFrame index is written as the first column)
for split_df, csv_path in (
    (train, '/vast/mcn8851/sc-train.csv'),
    (val, '/vast/mcn8851/sc-val.csv'),
    (test, '/vast/mcn8851/sc-test.csv'),
):
    split_df.to_csv(csv_path)