In [1]:
import json
import pandas as pd
import numpy as np
from collections import Counter
import matplotlib.pyplot as plt
import seaborn as sns


In [None]:
# Read the txt file
with open('data/TRDataChallenge2023.txt', 'r', encoding='utf-8') as file:
    content = file.read()

print(f"File size: {len(content)} characters")

File size: 331936506 characters


In [None]:
# Parse JSON dictionaries from the text file
data = []
lines = content.strip().split('\n')

for i, line in enumerate(lines):
    if line.strip():
        try:
            json_obj = json.loads(line)
            data.append(json_obj)
        except json.JSONDecodeError as e:
            print(f"Error parsing line {i+1}: {e}")


print(f"Successfully parsed {len(data)} JSON objects")

Successfully parsed 18000 JSON objects


# Question1 : Data analysis

In [4]:
# 2. Structure analysis
print("DATA STRUCTURE ANALYSIS")

if data:
    # to get an idea of the data:
        # print("Sample JSON object:")
        # print(json.dumps(data[0], indent=2))
    print("Keys in the JSON objects:")
    all_keys = set()
    for obj in data:
        if isinstance(obj, dict):
            all_keys.update(obj.keys())
    
    for key in sorted(all_keys):
        print(f"- {key}")

DATA STRUCTURE ANALYSIS
Keys in the JSON objects:
- documentId
- postures
- sections


In [5]:
# 3. Convert to DataFrame for easier analysis
print("DATAFRAME CONVERSION")
if data and isinstance(data[0], dict):
    df = pd.DataFrame(data)
    print(f"DataFrame shape: {df.shape}")
    print("\nDataFrame info:")
    df.info()
    print("\nFirst few rows:")
    print(df.head())
else:
    print("Data is not in a dictionary format")

DATAFRAME CONVERSION
DataFrame shape: (18000, 3)

DataFrame info:
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 18000 entries, 0 to 17999
Data columns (total 3 columns):
 #   Column      Non-Null Count  Dtype 
---  ------      --------------  ----- 
 0   documentId  18000 non-null  object
 1   postures    18000 non-null  object
 2   sections    18000 non-null  object
dtypes: object(3)
memory usage: 422.0+ KB

First few rows:
                          documentId  \
0  Ib4e590e0a55f11e8a5d58a2c8dcb28b5   
1  Ib06ab4d056a011e98c7a8e995225dbf9   
2  Iaa3e3390b93111e9ba33b03ae9101fb2   
3  I0d4dffc381b711e280719c3f0e80bdd0   
4  I82c7ef10d6d111e8aec5b23c3317c9c0   

                                            postures  \
0                                        [On Appeal]   
1  [Appellate Review, Sentencing or Penalty Phase...   
2          [Motion to Compel Arbitration, On Appeal]   
3     [On Appeal, Review of Administrative Decision]   
4                                        [On A

In [6]:
print("DESCRIPTIVE STATISTICS")


if not df.empty:
    print("Missingvalues:")
    print(df.isnull().sum())
    
    print("\nData types:")
    print(df.dtypes)
    
    print(f"\nDataset Overview:")
    print(f"- Total documents: {len(df)}")
    print(f"- Unique document IDs: {df['documentId'].nunique()}")
    
    # Analyze postures column
    print(f"\nPostures Analysis:")
    all_postures = []
    num_empty_postures = (df['postures'].apply(lambda x: isinstance(x, list) and len(x) == 0)).sum()
    print(f"Documents with empty postre lists: {num_empty_postures}")
            
    posture_counts = Counter(all_postures)
    print(f"Total unique postures: {len(posture_counts)}")
    print("Most commons postures:")
    for posture, count in posture_counts.most_common(10):
        print(f"  {posture}: {count}")
    
    # Analyze sections structure
    print(f"\nSections Analysis:")
    section_counts = []
    for sections in df['sections']:
        if isinstance(sections, list):
            section_counts.append(len(sections))
    
    if section_counts:
        print(f"Average sections per document: {np.mean(section_counts):.2f}")
        print(f"Minsections: {min(section_counts)}")
        print(f"Max sections: {max(section_counts)}")
        
    # Sample section analysis
    print(f"\nSample Section Structure:")
    if df['sections'].iloc[0] and isinstance(df['sections'].iloc[0], list):
        sample_section = df['sections'].iloc[0][0]
        if isinstance(sample_section, dict):
            print(f"Sction keys: {list(sample_section.keys())}")
            if 'paragraphs' in sample_section and isinstance(sample_section['paragraphs'], list):
                print(f"Paragrashs in first section: {len(sample_section['paragraphs'])}")
else:
    print("DataFrame empty")


DESCRIPTIVE STATISTICS
Missingvalues:
documentId    0
postures      0
sections      0
dtype: int64

Data types:
documentId    object
postures      object
sections      object
dtype: object

Dataset Overview:
- Total documents: 18000
- Unique document IDs: 18000

Postures Analysis:
Documents with empty postre lists: 923
Total unique postures: 0
Most commons postures:

Sections Analysis:
Average sections per document: 5.09
Minsections: 1
Max sections: 91

Sample Section Structure:
Sction keys: ['headtext', 'paragraphs']
Paragrashs in first section: 1


In [7]:

# 5. Comprehensive counts
print("DATASET COMPREHENSIVE COUNTS")
print("-" * 30)

# Count total documents
total_documents = len(df)
print(f"Total Documents: {total_documents:,}")

# Count unique postures across all documents
all_postures = []
for posture_list in df['postures']:
    if isinstance(posture_list, list):
        all_postures.extend(posture_list)

unique_postures = set(all_postures)
total_posture_instances = len(all_postures)
print(f"Total Posture Instances: {total_posture_instances:,}")
print(f"Unique Postures: {len(unique_postures):,}")

# Count total paragraphs across all documents
total_paragraphs = 0
for sections_list in df['sections']:
    if isinstance(sections_list, list):
        for section in sections_list:
            if isinstance(section, dict) and 'paragraphs' in section:
                if isinstance(section['paragraphs'], list):
                    total_paragraphs += len(section['paragraphs'])

print(f"Total Paragraphs: {total_paragraphs:,}")

avg_postures_per_doc = total_posture_instances / total_documents
avg_paragraphs_per_doc = total_paragraphs / total_documents

print(f"Average Postures per Document: {avg_postures_per_doc:.2f}")
print(f"Average Paragraphs per Document: {avg_paragraphs_per_doc:.2f}")

# Documents with multiple postures
multi_posture_docs_count = sum(1 for postures in df['postures'] if isinstance(postures, list) and len(postures) > 1)
print(f"Documents with Multiple Postures: {multi_posture_docs_count:,} ({(multi_posture_docs_count/total_documents)*100:.1f}%)")

# Detailed posture count distribution
print(f"DOCUMENTS BY NUMBER OF POSTURES:")
postures_per_document = []
for postures in df['postures']:
    if isinstance(postures, list):
        postures_per_document.append(len(postures))
    else:
        postures_per_document.append(0)

posture_distribution = Counter(postures_per_document)
print(f"{'Postures':<12} {'Documents':<12} {'Percentage':<12}")


for num_postures in sorted(posture_distribution.keys()):
    count = posture_distribution[num_postures]
    percentage = (count / total_documents) * 100
    posture_label = f"{num_postures} posture{'s' if num_postures != 1 else ''}"
    print(f"{posture_label:<12} {count:<12,} {percentage:<11.1f}%")

single_posture_count = posture_distribution.get(1, 0)
multi_posture_total = sum(count for postures, count in posture_distribution.items() if postures > 1)
print(f"Single posture documents: {single_posture_count:,} ({(single_posture_count/total_documents)*100:.1f}%)")
print(f"Multi-posture documents: {multi_posture_total:,} ({(multi_posture_total/total_documents)*100:.1f}%)")
print(f"Maximum postures in one document: {max(postures_per_document)}")

# Most common postures (top 10)
posture_counts = Counter(all_postures)
print(f"TOP 10 MOST COMMON POSTURES:")
for i, (posture, count) in enumerate(posture_counts.most_common(10), 1):
    percentage = (count / total_posture_instances) * 100
    print(f"{i:2d}. {posture}: {count:,} ({percentage:.1f}%)")


# Find documents with multiple postures
multi_posture_docs = []
single_posture_docs = []

for idx, postures in enumerate(df['postures']):
    if isinstance(postures, list):
        if len(postures) > 1:
            multi_posture_docs.append({
                'index': idx,
                'documentId': df.iloc[idx]['documentId'],
                'postures': postures,
                'num_postures': len(postures)
            })
        elif len(postures) == 1:
            single_posture_docs.append({
                'index': idx,
                'postures': postures
            })

print(f" Multi-posture documents: {len(multi_posture_docs):,}")
print(f"Single-posture  documents: {len(single_posture_docs):,}")



# Distribution of number of postures per document
posture_count_distribution = Counter([doc['num_postures'] for doc in multi_posture_docs])
print(f"DISTRIBUTION OF POSTURES IN MULTI-POSTURE DOCUMENTS:")
for num_postures in sorted(posture_count_distribution.keys()):
    count = posture_count_distribution[num_postures]
    percentage = (count / len(multi_posture_docs)) * 100
    print(f"  {num_postures} postures: {count:,} documents ({percentage:.1f}%)")

# Analyze which postures appear in multi-posture vs single-posture documents
postures_in_multi = []
postures_in_single = []

for doc in multi_posture_docs:
    postures_in_multi.extend(doc['postures'])

for doc in single_posture_docs:
    postures_in_single.extend(doc['postures'])

# Count frequencies
multi_posture_counts = Counter(postures_in_multi)
single_posture_counts = Counter(postures_in_single)

# All unique postures
all_unique_postures = set(postures_in_multi + postures_in_single)

# Calculate statistics for each posture
posture_stats = {}
for posture in all_unique_postures:
    multi_freq = multi_posture_counts.get(posture, 0)
    single_freq = single_posture_counts.get(posture, 0)
    total_freq = multi_freq + single_freq
    
    multi_percentage = (multi_freq / total_freq) * 100 if total_freq > 0 else 0
    
    posture_stats[posture] = {
        'multi_freq': multi_freq,
        'single_freq': single_freq,
        'total_freq': total_freq,
        'multi_percentage': multi_percentage
    }

# Sort by multi-posture percentage (descending)
sorted_by_multi = sorted(posture_stats.items(), key=lambda x: x[1]['multi_percentage'], reverse=True)

print(f"TOP 15 POSTURES MOST INVOLVED IN MULTI-POSTURE DOCUMENTS:  ")
print(f"{'Rank':<4} {'Posture':<40} {'Multi':<6} {'Single':<7} {'Total':<7} {'Multi%':<8}")



# Most common posture combinations in multi-posture documents
print(f"MOST COMMON POSTURE COMBINATIONS:")
combinations = [tuple(sorted(doc['postures'])) for doc in multi_posture_docs]
combination_counts = Counter(combinations)

for i, (combo, count) in enumerate(combination_counts.most_common(10), 1):
    percentage = (count / len(multi_posture_docs)) * 100
    print(f"{i:2d}. {' + '.join(combo)}")
    print(f"Count: {count:,} ({percentage:.1f}% of multi-posture docs)")
    print()


DATASET COMPREHENSIVE COUNTS
------------------------------
Total Documents: 18,000
Total Posture Instances: 27,659
Unique Postures: 224
Total Paragraphs: 542,169
Average Postures per Document: 1.54
Average Paragraphs per Document: 30.12
Documents with Multiple Postures: 8,959 (49.8%)
DOCUMENTS BY NUMBER OF POSTURES:
Postures     Documents    Percentage  
0 postures   923          5.1        %
1 posture    8,118        45.1       %
2 postures   7,604        42.2       %
3 postures   1,129        6.3        %
4 postures   190          1.1        %
5 postures   32           0.2        %
6 postures   2            0.0        %
7 postures   2            0.0        %
Single posture documents: 8,118 (45.1%)
Multi-posture documents: 8,959 (49.8%)
Maximum postures in one document: 7
TOP 10 MOST COMMON POSTURES:
 1. On Appeal: 9,197 (33.3%)
 2. Appellate Review: 4,652 (16.8%)
 3. Review of Administrative Decision: 2,773 (10.0%)
 4. Motion to Dismiss: 1,679 (6.1%)
 5. Sentencing or Penalty Phase 

# Question2: model

In [13]:
# Create the additional field in order to concatenate the text present in multiple paragraphers.
# Add some separators to facilitate tokenization and embedding
print(f"File size: {len(content)} characters")

data = []
lines = content.strip().split('\n')

for i, line in enumerate(lines):
    if line.strip():
        try:
            json_obj = json.loads(line)

            # combine text from sections
            sections = json_obj.get("sections", [])
            text_parts = []

            for sec in sections:
                headtext = sec.get("headtext", "")
                if headtext:
                    text_parts.append("[HEAD] " + headtext)
                paragraphs = sec.get("paragraphs", [])
                for para in paragraphs:
                    text_parts.append("[PARA] " + para)

            # Join sections with [SEP]
            full_text = " [SEP] ".join(text_parts)

            # Save new keys to the object
            json_obj["combined_text"] = full_text

            # Append to data list
            data.append(json_obj)

        except json.JSONDecodeError as e:
            print(f"Error parsing line {i+1}: {e}")
            print(f"Line content: {line[:100]}...")

print(f"\nSuccessfully parsed {len(data)} JSON objects")

# Example: print first processed document
print("\nExample processed text:")
print(data[0]["combined_text"][:500])  # Print first 100 characters to check

File size: 331936506 characters

Successfully parsed 18000 JSON objects

Example processed text:
[PARA] Plaintiff Dwight Watson (“Husband”) appeals from the trial court’s equitable distribution order entered 28 February 2017. On appeal, plaintiff contends that the trial court erred in its classification, valuation, and distribution of the parties’ property and in granting defendant Gertha  Watson (“Wife”) an unequal distribution of martial property. Because the trial court’s findings of fact do not support its conclusions of law and because the distributional factors found by the trial cour


In [12]:

import pandas as pd

df = pd.DataFrame([
    {
        "documentId": d["documentId"],
        "labels": d["postures"],
        "text": d["combined_text"]
    }
    for d in data
])

print(df.head())

                          documentId  \
0  Ib4e590e0a55f11e8a5d58a2c8dcb28b5   
1  Ib06ab4d056a011e98c7a8e995225dbf9   
2  Iaa3e3390b93111e9ba33b03ae9101fb2   
3  I0d4dffc381b711e280719c3f0e80bdd0   
4  I82c7ef10d6d111e8aec5b23c3317c9c0   

                                              labels  \
0                                        [On Appeal]   
1  [Appellate Review, Sentencing or Penalty Phase...   
2          [Motion to Compel Arbitration, On Appeal]   
3     [On Appeal, Review of Administrative Decision]   
4                                        [On Appeal]   

                                                text  
0  [PARA] Plaintiff Dwight Watson (“Husband”) app...  
1  [PARA] After pleading guilty, William Jerome H...  
2  [PARA] Frederick Greene, the plaintiff below, ...  
3  [PARA] Appeal from an amended judgment of the ...  
4  [PARA] Order, Supreme Court, New York County (...  


In [None]:
# TAKE TOPk and keep only labels from the topk
#  Exclude documents that end up with no remaining labels
from collections import Counter

# data is a list of dicts, each with a 'postures' field
label_counter = Counter()
for entry in data:
    label_counter.update(entry["postures"])
# Print most common labels
print(label_counter.most_common(20))

k = 5  
top_k_labels = [label for label, _ in label_counter.most_common(k)]
print(f"top {5} labels :{top_k_labels}")
filtered_data = []
for entry in data:
    # Keep only labels that are in top k
    new_labels = [label for label in entry["postures"] if label in top_k_labels]

    if new_labels:
        # Update labels
        entry["postures"] = new_labels
        filtered_data.append(entry)

print(f"Remaining documents after filtering: {len(filtered_data)}")

# next steps: take care also of the kappa score?

[('On Appeal', 9197), ('Appellate Review', 4652), ('Review of Administrative Decision', 2773), ('Motion to Dismiss', 1679), ('Sentencing or Penalty Phase Motion or Objection', 1342), ('Trial or Guilt Phase Motion or Objection', 1097), ("Motion for Attorney's Fees", 612), ('Post-Trial Hearing Motion', 512), ('Motion for Preliminary Injunction', 364), ('Motion to Dismiss for Lack of Subject Matter Jurisdiction', 343), ('Motion to Compel Arbitration', 255), ('Motion for New Trial', 226), ('Petition to Terminate Parental Rights', 219), ('Motion for Judgment as a Matter of Law (JMOL)/Directed Verdict', 212), ('Motion for Reconsideration', 206), ('Motion to Dismiss for Lack of Personal Jurisdiction', 204), ('Motion for Costs', 168), ('Juvenile Delinquency Proceeding', 146), ('Motion for Default Judgment/Order of Default', 143), ('Motion to Dismiss for Lack of Standing', 137)]
top 5 labels :['On Appeal', 'Appellate Review', 'Review of Administrative Decision', 'Motion to Dismiss', 'Sentencing

## training phase

In [16]:
from sklearn.model_selection import train_test_split

# First split: train vs (val + test)
train_data, temp_data = train_test_split(
    filtered_data, 
    test_size=0.3,  # 30% goes to temp (val + test)
    random_state=42,
    shuffle=True
)

# Second split: val vs test (half-half from temp)
val_data, test_data = train_test_split(
    temp_data, 
    test_size=0.5,  # Split remaining 30% equally
    random_state=42,
    shuffle=True
)

# Check sizes
print(f"Train set size: {len(train_data)}")
print(f"val set size: {len(val_data)}")
print(f"Test set size: {len(test_data)}")

Train set size: 10962
val set size: 2349
Test set size: 2350


In [17]:
import random

# Fix random seed for reproducibility
random.seed(42)
# Define desired sample sizes
sample_sizes = [500,1000, 2000, 5000, 10000]
# Save all subsets
train_subsets = {}

for size in sample_sizes:
    # Check if wehave enough samples
    if len(train_data) >= size:
        subset = random.sample(train_data, size)
        train_subsets[size] = subset
        print(f"Created subset with {size} samples.")
    else:
        print(f"Not enough data to create subset with {size} samples (only {len(train_data)} avlable).")

# Example: look at first document text in 1k subset
print(train_subsets[500][0]["combined_text"][:500])

Created subset with 500 samples.
Created subset with 1000 samples.
Created subset with 2000 samples.
Created subset with 5000 samples.
Created subset with 10000 samples.
[PARA] Proceeding pursuant to CPLR article 78 (transferred to this Court by order of the Supreme Court, entered in Albany County) to review a determination of the Department of Health which found decedent ineligible for certain Medicaid benefits. [SEP] [PARA] In September 2003, Paul Hettinger (hereinafter decedent), a widower, executed a durable general power of attorney appointing Sharon Williams, an individual identified in the record as his cousin-in-law, as his attorney-in-fact.   The power 


## DistilRoBERTa


In [None]:
# Prepare label binarizer
from sklearn.preprocessing import MultiLabelBinarizer


label_list = top_k_labels
# Create binarizer
mlb = MultiLabelBinarizer(classes=label_list)
mlb.fit([label_list])


# Use the previously created 500-sample subset
subset_500 = train_subsets[500]

# Texts and labels from the subset
train_texts = [entry["combined_text"] for entry in subset_500]
train_labels_raw = [entry["postures"] for entry in subset_500]


val_texts = [entry["combined_text"] for entry in val_data]
val_labels_raw = [entry["postures"] for entry in val_data]

test_texts = [entry["combined_text"] for entry in test_data]
test_labels_raw = [entry["postures"] for entry in test_data]


In [19]:

train_labels = mlb.transform(train_labels_raw)
val_labels = mlb.transform(val_labels_raw)
test_labels = mlb.transform(test_labels_raw)

In [22]:
# tojkenizer

# tokenize text
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilroberta-base")

In [23]:
import torch
from torch.utils.data import Dataset

class LegalMLCDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_len=512):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        labels = self.labels[idx]
        inputs = self.tokenizer(
            text,
            max_length=self.max_len,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
        item = {key: val.squeeze(0) for key, val in inputs.items()}
        item["labels"] = torch.tensor(labels, dtype=torch.float)
        return item

In [24]:
# Create dataset objects
train_dataset = LegalMLCDataset(train_texts, train_labels, tokenizer)
val_dataset = LegalMLCDataset(val_texts, val_labels, tokenizer)
test_dataset = LegalMLCDataset(test_texts, test_labels, tokenizer)

In [25]:
# Load model and modify for multi-label
from transformers import AutoModelForSequenceClassification

num_labels = len(label_list)
model = AutoModelForSequenceClassification.from_pretrained(
    "distilroberta-base",
    num_labels=num_labels,
    problem_type="multi_label_classification"
)

Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at distilroberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [26]:
# M1 confiuration device
import torch

if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print(f"Using device: {device}")

model = model.to(device)

Using device: mps


In [27]:
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./results",
    learning_rate=2e-5,
    per_device_train_batch_size=2,      # reduce it because of laptop limittions
    per_device_eval_batch_size=2,
    num_train_epochs=3,
    weight_decay=0.01,
    evaluation_strategy="epoch",        # Evaluate on dev set each epoch
    save_strategy="epoch",              # Save checkpoint each epoch
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",  # Use lowest validation loss
)



In [28]:
from sklearn.metrics import f1_score, precision_score, recall_score
import numpy as np

def compute_metrics(pred):
    logits, labels = pred
    probs = 1 / (1 + np.exp(-logits))  # sigmoid
    preds = (probs >= 0.5).astype(int) # threshold

    metrics = {
        'micro/f1': f1_score(labels, preds, average='micro', zero_division=0),
        'macro/f1': f1_score(labels, preds, average='macro', zero_division=0),
        'micro/precision': precision_score(labels, preds, average='micro', zero_division=0),
        'macro/precision': precision_score(labels, preds, average='macro', zero_division=0),
        'micro/recall': recall_score(labels, preds, average='micro', zero_division=0),
        'macro/recall': recall_score(labels, preds, average='macro', zero_division=0),
    }
    return metrics

In [31]:
from transformers import Trainer

    
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
)


### option A: subset of 500 documents and top5 labels

In [32]:
trainer.train()

Epoch,Training Loss,Validation Loss,Micro/f1,Macro/f1,Micro/precision,Macro/precision,Micro/recall,Macro/recall
1,No log,0.307804,0.728411,0.355044,0.827155,0.334795,0.650729,0.377927
2,0.317300,0.250385,0.801531,0.504835,0.86643,0.533214,0.745677,0.495168
3,0.317300,0.242956,0.816088,0.523792,0.885363,0.533953,0.756867,0.519173


TrainOutput(global_step=750, training_loss=0.2781613210042318, metrics={'train_runtime': 410.2974, 'train_samples_per_second': 3.656, 'train_steps_per_second': 1.828, 'total_flos': 198711728640000.0, 'train_loss': 0.2781613210042318, 'epoch': 3.0})

In [33]:

metrics = trainer.evaluate(test_dataset)
print(metrics)

{'eval_loss': 0.2164878249168396, 'eval_micro/f1': 0.8339106654512306, 'eval_macro/f1': 0.5349275550375603, 'eval_micro/precision': 0.9000393545848091, 'eval_macro/precision': 0.5353950156860148, 'eval_micro/recall': 0.7768342391304348, 'eval_macro/recall': 0.536347816782792, 'eval_runtime': 75.6638, 'eval_samples_per_second': 31.058, 'eval_steps_per_second': 15.529, 'epoch': 3.0}


### option B: subset of 1000 documents and top5 labels

In [36]:
# Use the previously created 1000-sample subset
subset_1000 = train_subsets[1000]

# Texts and labels from the subset
train_texts = [entry["combined_text"] for entry in subset_1000]
train_labels_raw = [entry["postures"] for entry in subset_1000]


val_texts = [entry["combined_text"] for entry in val_data]
val_labels_raw = [entry["postures"] for entry in val_data]

test_texts = [entry["combined_text"] for entry in test_data]
test_labels_raw = [entry["postures"] for entry in test_data]



train_labels = mlb.transform(train_labels_raw)
val_labels = mlb.transform(val_labels_raw)
test_labels = mlb.transform(test_labels_raw)


train_dataset = LegalMLCDataset(train_texts, train_labels, tokenizer)
val_dataset = LegalMLCDataset(val_texts, val_labels, tokenizer)
test_dataset = LegalMLCDataset(test_texts, test_labels, tokenizer)

In [37]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
)

In [38]:
trainer.train()
metrics = trainer.evaluate(test_dataset)
print(metrics)

Epoch,Training Loss,Validation Loss,Micro/f1,Macro/f1,Micro/precision,Macro/precision,Micro/recall,Macro/recall
1,0.2362,0.210849,0.844989,0.73006,0.848019,0.782337,0.84198,0.740142
2,0.1677,0.20586,0.860527,0.685403,0.898818,0.836328,0.825365,0.663118
3,0.1305,0.200388,0.876796,0.801184,0.873989,0.802925,0.87962,0.800582


{'eval_loss': 0.17552222311496735, 'eval_micro/f1': 0.8896388795140061, 'eval_macro/f1': 0.8136859754679941, 'eval_micro/precision': 0.8839704896042925, 'eval_macro/precision': 0.8080036994610944, 'eval_micro/recall': 0.8953804347826086, 'eval_macro/recall': 0.8196197054883976, 'eval_runtime': 75.2187, 'eval_samples_per_second': 31.242, 'eval_steps_per_second': 15.621, 'epoch': 3.0}


### option C: subset of 500 documents and top10 labels

In [40]:
k = 10
top_k_labels = [label for label, _ in label_counter.most_common(k)]
print(f"top {10} labels :{top_k_labels}")
filtered_data = []
for entry in data:
    # Keep only labels that are in top k
    new_labels = [label for label in entry["postures"] if label in top_k_labels]

    if new_labels:
        # Update labels
        entry["postures"] = new_labels
        filtered_data.append(entry)

print(f"Remaining documents after filtering: {len(filtered_data)}")


# Use the previously created 1000-sample subset
subset_500 = train_subsets[500]

# Texts and labels from the subset
train_texts = [entry["combined_text"] for entry in subset_1000]
train_labels_raw = [entry["postures"] for entry in subset_1000]


val_texts = [entry["combined_text"] for entry in val_data]
val_labels_raw = [entry["postures"] for entry in val_data]

test_texts = [entry["combined_text"] for entry in test_data]
test_labels_raw = [entry["postures"] for entry in test_data]



train_labels = mlb.transform(train_labels_raw)
val_labels = mlb.transform(val_labels_raw)
test_labels = mlb.transform(test_labels_raw)


train_dataset = LegalMLCDataset(train_texts, train_labels, tokenizer)
val_dataset = LegalMLCDataset(val_texts, val_labels, tokenizer)
test_dataset = LegalMLCDataset(test_texts, test_labels, tokenizer)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
)

trainer.train()
metrics = trainer.evaluate(test_dataset)
print(metrics)

top 10 labels :['On Appeal', 'Appellate Review', 'Review of Administrative Decision', 'Motion to Dismiss', 'Sentencing or Penalty Phase Motion or Objection', 'Trial or Guilt Phase Motion or Objection', "Motion for Attorney's Fees", 'Post-Trial Hearing Motion', 'Motion for Preliminary Injunction', 'Motion to Dismiss for Lack of Subject Matter Jurisdiction']
Remaining documents after filtering: 16140


Epoch,Training Loss,Validation Loss,Micro/f1,Macro/f1,Micro/precision,Macro/precision,Micro/recall,Macro/recall
1,0.1324,0.213335,0.871021,0.803833,0.852092,0.774558,0.89081,0.838396
2,0.0905,0.258268,0.872913,0.773823,0.877097,0.813184,0.868769,0.752101
3,0.0655,0.251022,0.880777,0.802876,0.869997,0.79636,0.891828,0.810407


{'eval_loss': 0.19714035093784332, 'eval_micro/f1': 0.8770940454470061, 'eval_macro/f1': 0.8097453084242956, 'eval_micro/precision': 0.8570502431118314, 'eval_macro/precision': 0.7800796262236082, 'eval_micro/recall': 0.8980978260869565, 'eval_macro/recall': 0.846107740694551, 'eval_runtime': 75.5588, 'eval_samples_per_second': 31.102, 'eval_steps_per_second': 15.551, 'epoch': 3.0}
