In [1]:
from bs4 import BeautifulSoup
import os
from glob import glob
import dask
import dask.array as da
import dask.bag as db
import dask.dataframe as dd
from dask import delayed
import pandas as pd
from distributed import Client
from dask_jobqueue import SLURMCluster
from IPython.display import display
import matplotlib.pyplot as plt
import time
import numpy as np
import pyarrow
from dask.diagnostics import ProgressBar
import time
import csv
from tqdm import tqdm
from tqdm.notebook import tqdm
tqdm.pandas()
import io
import sys
from dask.diagnostics import ProgressBar
from sklearn.model_selection import GroupShuffleSplit
# Show every column and the full contents of each cell when displaying frames.
pd.set_option('display.max_columns', None, 'display.max_colwidth', None)
# Raise the stdlib csv module's per-field size limit to the platform maximum.
# NOTE(review): pandas' default C parser does not consult this limit; it only
# affects the csv module itself (and read_csv(engine='python')).
csv.field_size_limit(sys.maxsize)


# Execution mode:
#   True  -> single-machine Dask client (use while developing)
#   False -> SLURM-backed Dask cluster (full runs)
LOCAL = False


if not LOCAL:
    # SLURM cluster. Worker logs go to /scratch/<SLURM_JOB_USER>; this raises
    # KeyError when SLURM_JOB_USER is unset (e.g. outside a SLURM job).
    slurm_user = os.environ['SLURM_JOB_USER']
    # Workers must run the same Dask version as this notebook.
    dask_python = f'/scratch/work/public/dask/{dask.__version__}/bin/python'

    cluster = SLURMCluster(
        memory='64GB',   # per-job memory request
        cores=8,         # per-job core request
        python=dask_python,
        # NOTE(review): newer dask_jobqueue releases rename `job_extra` to
        # `job_extra_directives`; confirm against the installed version.
        job_extra=[f'--export=NONE --output=/scratch/{slurm_user}/slurm-%j.out'],
    )

    # Site-specific submit command used instead of the library default.
    cluster.submit_command = 'slurm'
    cluster.scale(200)   # request 200 workers

    display(cluster)
    client = Client(cluster)
else:
    client = Client()

display(client)

  from distributed.utils import tmpfile


Tab(children=(HTML(value='<div class="jp-RenderedHTMLCommon jp-RenderedHTML jp-mod-trusted jp-OutputArea-outpu…

0,1
Connection method: Cluster object,Cluster type: dask_jobqueue.SLURMCluster
Dashboard: http://10.32.35.126:8787/status,

0,1
Dashboard: http://10.32.35.126:8787/status,Workers: 0
Total threads: 0,Total memory: 0 B

0,1
Comm: tcp://10.32.35.126:39255,Workers: 0
Dashboard: http://10.32.35.126:8787/status,Total threads: 0
Started: Just now,Total memory: 0 B


In [2]:
# Decision text; its `cluster_id` column is the key used to join cluster metadata.
TEXT_OPINIONS_CSV = '/vast/amh9750/text-opinions-data-sc.csv'
text_df = pd.read_csv(TEXT_OPINIONS_CSV)

In [3]:
# Cluster-level metadata, including the decision-direction label
# (`scdb_decision_direction`); its `id` column is the cluster id.
CLUSTER_DATA_CSV = '/vast/amh9750/opinions-cluster-data-sc-new.csv'
cluster_df = pd.read_csv(CLUSTER_DATA_CSV)

In [4]:
# Restore default column-width truncation so long decision texts don't flood output.
pd.reset_option('display.max_colwidth')

# Drop the leftover written-out index column(s) ("Unnamed: 0", ...). Selecting by
# name instead of dropping columns[0] keeps this cell safe to re-run: a second run
# no longer deletes a real column.
text_df = text_df.loc[:, ~text_df.columns.str.startswith('Unnamed')]

# The cluster file calls its key "id"; rename it so both frames share "cluster_id".
cluster_df = cluster_df.rename(columns={"id": "cluster_id"})

# Attach cluster metadata to every opinion row. validate='many_to_one' raises if a
# cluster_id repeats in cluster_df, which would otherwise silently duplicate rows.
data_df = text_df.merge(cluster_df, how='left', on='cluster_id', validate='many_to_one')

In [5]:
# Keep only the columns used downstream.
filtered_df = data_df[['id', 'cluster_id', 'type', 'decision_text', 'date_filed', 'scdb_decision_direction']]

# Opinion types to keep.
types = ["010combined", "015unanimous", "020lead", "025plurality", "030concurrence", "035concurrenceinpart", "040dissent"]

# Build `data` as one chain so every step yields a fresh frame. Assigning a column
# onto a boolean-filtered slice (the previous approach) raises SettingWithCopyWarning.
data = (
    filtered_df
    .loc[filtered_df['type'].isin(types)]
    # Rows without text or without a label are unusable.
    .dropna(subset=['decision_text', 'scdb_decision_direction'])
    # 3.0 marks decision directions that could not be discerned.
    .loc[lambda df: df['scdb_decision_direction'] != 3.0]
    # Labels stay 1.0 / 2.0 as categories; map({1.0: 0, 2.0: 1}) if a 0/1 target is needed.
    .assign(scdb_decision_direction=lambda df: df['scdb_decision_direction'].astype('category'))
)
display(data)

Unnamed: 0,id,cluster_id,type,decision_text,date_filed,scdb_decision_direction
1,90814,90814,010combined,\n\n Mr. Chief Justice Waite\n \n\n delivere...,1882-11-06,1.0
4,90598,90598,010combined,\n\n Mr. Chief Justice Waite\n \n\n delivere...,1882-04-18,1.0
5,90599,90599,010combined,\n\n Mia. Justice Matthews\n \n\n delivered ...,1882-04-24,1.0
6,90600,90600,010combined,"\n\n Mr. Chief Justice Waite,\n \n\n after s...",1882-04-18,1.0
7,9417336,90602,020lead,"\nMr. Chief' Justice Waite,\nafter stating the...",1882-05-18,2.0
...,...,...,...,...,...,...
45502,91585,91585,010combined,\n117 U.S. 96 (1886)\nLEATHER MANUFACTURERS' B...,1886-03-01,2.0
45503,100202,100202,010combined,\n\n Mb. Chief Justice Taft\n \n\n delivered...,1923-05-07,1.0
45504,89804,89804,010combined,\n\n Mr. Justice Swayne\n \n\n delivered the...,1878-11-18,1.0
45505,101012,101012,010combined,\n\n Mr. Justice Brandéis\n \n\n delivered t...,1927-02-21,1.0


In [6]:
# Distinct non-null values per column; cluster_id has fewer than id, so several opinions share a cluster.
data.nunique(axis='index', dropna=True)

id                         41332
cluster_id                 24386
type                           6
decision_text              41307
date_filed                  5406
scdb_decision_direction        2
dtype: int64

In [7]:
SPLIT_SEED = 7

# Stage 1: hold out the test set. GroupShuffleSplit's test_size is a fraction of
# *groups* (cluster_ids), not rows: every opinion of a cluster lands on one side,
# so the row share is only approximately 20%.
splitter = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=SPLIT_SEED)
train_temp_inds, test_inds = next(splitter.split(data, groups=data['cluster_id']))
train_temp = data.iloc[train_temp_inds]

# Stage 2: split the remaining groups 75/25 into train/val (~60/20/20 of groups overall).
# These indices are positions within train_temp, not within data.
splitter = GroupShuffleSplit(n_splits=1, test_size=0.25, random_state=SPLIT_SEED)
train_inds, val_inds = next(splitter.split(train_temp, groups=train_temp['cluster_id']))

test = data.iloc[test_inds]
train = train_temp.iloc[train_inds]
val = train_temp.iloc[val_inds]

In [11]:
# Leakage check: the splits are grouped by cluster_id, so no cluster_id may
# appear in more than one of train / val / test.
cluster_ids_train = train['cluster_id'].tolist()
cluster_ids_val = val['cluster_id'].tolist()
cluster_ids_test = test['cluster_id'].tolist()


def _report_overlap(name_a, ids_a, name_b, ids_b):
    """Print whether two splits share any cluster_id and return the shared ids as a set."""
    common = set(ids_a).intersection(ids_b)
    if common:
        print(f"{name_a}/{name_b}: {len(common)} cluster_id(s) in common")
    else:
        print(f"{name_a}/{name_b}: No elements in common")
    return common


common_train_val = _report_overlap('train', cluster_ids_train, 'val', cluster_ids_val)
common_train_test = _report_overlap('train', cluster_ids_train, 'test', cluster_ids_test)
common_val_test = _report_overlap('val', cluster_ids_val, 'test', cluster_ids_test)

# Fail loudly instead of just printing, so Run All stops on leakage.
assert not (common_train_val or common_train_test or common_val_test), \
    "cluster_id leakage between train/val/test splits"


No elements in common
No elements in common
No elements in common


In [12]:
# Persist the splits. Off by default so "Run All" doesn't overwrite the files;
# set to True to (re)write them.
# NOTE: to_csv writes the index as an unnamed first column, which comes back as
# "Unnamed: 0" when re-read with pd.read_csv.
SAVE_SPLITS = False

if SAVE_SPLITS:
    train.to_csv('/vast/amh9750/sc-train.csv')
    val.to_csv('/vast/amh9750/sc-val.csv')
    test.to_csv('/vast/amh9750/sc-test.csv')