In [1]:
from echr import *
from classifier import *
import pandas as pd
import matplotlib.pyplot as plt
from transformers import LongformerTokenizer
import numpy as np
import re
from tqdm import tqdm

%load_ext autoreload
%autoreload 2

In [2]:
def summarize_split(train_df, test_df, name="Dataset"):
    total = len(train_df) + len(test_df)
    test_pct = (len(test_df) / total) * 100 if total > 0 else 0
    lbl_dst_test = (len(test_df[test_df['violation']==1]) / len(test_df)) * 100
    lbl_dst_train = (len(train_df[train_df['violation']==1]) / len(train_df)) * 100
    print(f"📊 {name} split summary:")
    print(f"- Train size: {len(train_df)}")
    print(f"- Test size:  {len(test_df)}")
    print(f"- Test %:     {test_pct:.2f}% of total ({total} cases)")
    print(f"- Test lbl:    {lbl_dst_test:.2f}% violation")
    print(f"- Train lbl:    {lbl_dst_train:.2f}% violation\n")

In [3]:
json_path = 'datasets/original/cases_04_2024.json'
article = '6' 
part = 'facts'
split_year = 2015 # year on which to split training/test

In [None]:
chamber_df = create_dataset(json_path, article, part, 'Chamber')
grand_chamber_df = create_dataset(json_path, article, part, 'Grand Chamber')

In [None]:
# Grand Chamber
grand_chamber_df_test = grand_chamber_df[grand_chamber_df['year'] >= split_year]
grand_chamber_df_train = grand_chamber_df[grand_chamber_df['year'] < split_year]
summarize_split(grand_chamber_df_train, grand_chamber_df_test, name="Grand Chamber")

#Chamber
chamber_df_test = chamber_df[chamber_df['year'] >= split_year]
chamber_df_train = chamber_df[chamber_df['year'] < split_year]
summarize_split(chamber_df_train, chamber_df_test, name="Chamber")

In [6]:
chamber_df_test.head(10)

Unnamed: 0,unique_id,id,body,text,year,violation,sample_weight,judgment_info
1,1,001-153349,Chamber,"5.The applicants were born in 1961, 1964 and 1...",2015,1,1.0,"FOR THESE REASONS, THE COURT UNANIMOUSLY\n1. ..."
2,2,001-153349,Chamber,"5.The applicants were born in 1961, 1964 and 1...",2015,1,1.0,"FOR THESE REASONS, THE COURT UNANIMOUSLY\n1. ..."
7,7,001-164199,Chamber,5.The applicant was born in 1939 and lives in ...,2016,1,1.0,"FOR THESE REASONS, THE COURT, UNANIMOUSLY,\n1...."
17,17,001-200866,Chamber,6.Details of the applicants are set out in the...,2020,1,1.0,"FOR THESE REASONS, THE COURT, UNANIMOUSLY,\nDE..."
18,18,001-200866,Chamber,6.Details of the applicants are set out in the...,2020,1,1.0,"FOR THESE REASONS, THE COURT, UNANIMOUSLY,\nDE..."
26,26,001-170633,Chamber,5.The first applicant was born in 1933 and liv...,2017,0,1.0,"FOR THESE REASONS, THE COURT, UNANIMOUSLY,\n....."
29,29,001-183120,Chamber,5.The applicant was born in 1975. He is curren...,2018,1,1.0,"FOR THESE REASONS, THE COURT UNANIMOUSLY,\n1. ..."
37,37,001-202539,Chamber,1.The applicant was born in 1937 and lives in ...,2020,1,1.0,"FOR THESE REASONS, THE COURT, UNANIMOUSLY,\nDE..."
41,41,001-157537,Chamber,5.The applicant was born in 1949 and lives in ...,2015,1,1.0,"FOR THESE REASONS, THE COURT, UNANIMOUSLY,\n1...."
44,44,001-200351,Chamber,5.The applicant was born in 1960 and lives in ...,2020,1,1.0,"FOR THESE REASONS, THE COURT, UNANIMOUSLY,\nDE..."


In [None]:
# Balance multiple_datasets
balanced_sets_chamber = generate_balanced_subsets(chamber_df_train, n=7, random_seed=42)
balanced_sets_grand_chamber = generate_balanced_subsets(grand_chamber_df_train, n=7, random_seed=42)

In [None]:
# Store csvs
grand_chamber_df_test.to_csv('datasets/test_grand_chamber.csv', index=False)
chamber_df_test.to_csv('datasets/test_chamber.csv', index=False)
for idx, (c_df, gc_df) in enumerate(zip(balanced_sets_chamber, balanced_sets_grand_chamber)):
    c_df.to_csv('datasets/train_chamber_'+str(idx)+'.csv', index=False)
    gc_df.to_csv('datasets/train_grand_chamber_'+str(idx)+'.csv', index=False)

### Load data

In [19]:
all_datasets = {
    'test': {
        'chamber': pd.read_csv('datasets/test_chamber.csv'),
        'grand_chamber': pd.read_csv('datasets/test_grand_chamber.csv'),
    },
    'train': {
        'chamber': {idx: pd.read_csv('datasets/train_chamber_'+str(idx)+'.csv') for idx in range(0,7)},    
        'grand_chamber': {idx: pd.read_csv('datasets/train_grand_chamber_'+str(idx)+'.csv') for idx in range(0,7)},    
    }
}

### Store embeddings

In [17]:
import os
import pandas as pd
from tqdm import tqdm
import torch
from transformers import AutoTokenizer, AutoModel

# Setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
MODEL_NAME = 'allenai/longformer-base-4096'

# Load Longformer model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME).to(device)
model.eval()

# Helper: Get mean pooled embedding for a single text
@torch.no_grad()
def get_longformer_embedding(text):
    tokens = tokenizer(
        text,
        padding='max_length',
        truncation=True,
        max_length=4096,
        return_tensors='pt'
    ).to(device)
    
    output = model(**tokens).last_hidden_state  # shape: [1, seq_len, hidden_size]
    attention_mask = tokens['attention_mask'].unsqueeze(-1)  # [1, seq_len, 1]
    masked_output = output * attention_mask  # Zero out paddings
    summed = masked_output.sum(dim=1)
    count = attention_mask.sum(dim=1).clamp(min=1e-9)
    mean_pooled = summed / count  # [1, hidden_size]
    
    return mean_pooled.squeeze(0).cpu().numpy()

# Process and store embeddings
def compute_and_store_chamber_embeddings(all_datasets, output_dir='datasets/embeddings/'):
    os.makedirs(output_dir, exist_ok=True)
    tqdm.pandas()

    for body in all_datasets['train'].keys():
        for split_id, df in all_datasets['train'][body].items():
            print(f"Processing {body} split {split_id}...")
            
            # Calculate embeddings
            df = df.copy()
            df['embedding'] = df['text'].progress_apply(get_longformer_embedding)
            
            # Save as pickle to preserve vectors
            df.to_pickle(os.path.join(output_dir, f'{body}_split_{split_id}.pkl'))

In [18]:
compute_and_store_chamber_embeddings(all_datasets)

Processing chamber split 0...


100%|██████████████████████████████████████████████████████████████████████████████| 1260/1260 [02:20<00:00,  8.98it/s]


Processing chamber split 1...


100%|██████████████████████████████████████████████████████████████████████████████| 1260/1260 [02:22<00:00,  8.86it/s]


Processing chamber split 2...


100%|██████████████████████████████████████████████████████████████████████████████| 1260/1260 [02:27<00:00,  8.52it/s]


Processing chamber split 3...


100%|██████████████████████████████████████████████████████████████████████████████| 1260/1260 [02:28<00:00,  8.51it/s]


Processing chamber split 4...


100%|██████████████████████████████████████████████████████████████████████████████| 1260/1260 [02:28<00:00,  8.49it/s]


Processing chamber split 5...


100%|██████████████████████████████████████████████████████████████████████████████| 1260/1260 [02:31<00:00,  8.34it/s]


Processing chamber split 6...


100%|██████████████████████████████████████████████████████████████████████████████| 1260/1260 [02:32<00:00,  8.28it/s]


Processing grand_chamber split 0...


100%|████████████████████████████████████████████████████████████████████████████████| 116/116 [00:14<00:00,  7.97it/s]


Processing grand_chamber split 1...


100%|████████████████████████████████████████████████████████████████████████████████| 116/116 [00:14<00:00,  7.99it/s]


Processing grand_chamber split 2...


100%|████████████████████████████████████████████████████████████████████████████████| 116/116 [00:14<00:00,  8.07it/s]


Processing grand_chamber split 3...


100%|████████████████████████████████████████████████████████████████████████████████| 116/116 [00:14<00:00,  7.78it/s]


Processing grand_chamber split 4...


100%|████████████████████████████████████████████████████████████████████████████████| 116/116 [00:14<00:00,  7.93it/s]


Processing grand_chamber split 5...


100%|████████████████████████████████████████████████████████████████████████████████| 116/116 [00:14<00:00,  7.92it/s]


Processing grand_chamber split 6...


100%|████████████████████████████████████████████████████████████████████████████████| 116/116 [00:14<00:00,  7.91it/s]


### Calculate and store propensities

In [33]:
import os
import pickle
import pandas as pd
import numpy as np
from sklearn.pipeline import Pipeline
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler

def load_pickled_df(split_id, kind, folder='datasets/embeddings'):
    path = os.path.join(folder, f"{kind}_split_{split_id}.pkl")
    with open(path, 'rb') as f:
        return pickle.load(f)

def train_propensity_model(reference_df, target_df, score_on='target'):
    """
    Train a logistic regression model to distinguish between reference_df (label=0)
    and target_df (label=1). Returns the propensity scores for the dataset specified
    in 'score_on' ('reference' or 'target').
    """
    reference_df = reference_df.copy()
    target_df = target_df.copy()
    reference_df['target'] = 0
    target_df['target'] = 1

    combined_df = pd.concat([reference_df, target_df], ignore_index=True)
    X = np.vstack(combined_df['embedding'].values)
    y = combined_df['target'].values

    model = Pipeline([
        ('scaler', StandardScaler()),
        ('clf', LogisticRegression(solver='lbfgs', max_iter=1000))
    ])
    model.fit(X, y)

    if score_on == 'reference':
        X_score = np.vstack(reference_df['embedding'].values)
    elif score_on == 'target':
        X_score = np.vstack(target_df['embedding'].values)
    else:
        raise ValueError("score_on must be either 'reference' or 'target'")

    propensities = model.predict_proba(X_score)[:, 1]
    return propensities

def compute_and_save_propensities(split_id, input_folder='datasets/embeddings', output_folder='datasets/propensities'):
    """
    Computes propensity scores of grand chamber cases (target) based on chamber cases (reference).
    """
    # Load pickled dataframes with embeddings
    chamber_df = load_pickled_df(split_id, 'chamber', input_folder)
    grand_chamber_df = load_pickled_df(split_id, 'grand_chamber', input_folder)

    # Compute propensities for grand chamber cases
    propensities = train_propensity_model(chamber_df, grand_chamber_df, score_on='target')
    grand_chamber_df['propensity_score'] = propensities

    # Compute inverse propensity weights
    grand_chamber_df['sample_weight'] = 1.0 / grand_chamber_df['propensity_score']

    # Replace infinite/very large weights with cap
    grand_chamber_df['sample_weight'] = grand_chamber_df['sample_weight'].clip(upper=20.0)

    # Save updated grand_chamber_df
    os.makedirs(output_folder, exist_ok=True)
    out_path = os.path.join(output_folder, f'grand_chamber_with_propensity_split_{split_id}.pkl')
    with open(out_path, 'wb') as f:
        pickle.dump(grand_chamber_df, f)
    print(f"✅ Stored: {out_path}")


In [34]:
for split_id in range(0, 7):
    compute_and_save_propensities(split_id=split_id)

✅ Stored: datasets/balanced/propensities\grand_chamber_with_propensity_split_0.pkl
✅ Stored: datasets/balanced/propensities\grand_chamber_with_propensity_split_1.pkl
✅ Stored: datasets/balanced/propensities\grand_chamber_with_propensity_split_2.pkl
✅ Stored: datasets/balanced/propensities\grand_chamber_with_propensity_split_3.pkl
✅ Stored: datasets/balanced/propensities\grand_chamber_with_propensity_split_4.pkl
✅ Stored: datasets/balanced/propensities\grand_chamber_with_propensity_split_5.pkl
✅ Stored: datasets/balanced/propensities\grand_chamber_with_propensity_split_6.pkl


### Calculate and store nearest_neightbor_labels

In [25]:
import os
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm

def impute_nn_labels_and_store(split_ids, input_dir='datasets/embeddings', label_col='violation', new_col='label_nn'):
    for split_id in tqdm(split_ids, desc="Imputing NN labels"):
        chamber_path = os.path.join(input_dir, f'chamber_split_{split_id}.pkl')
        grand_path = os.path.join(input_dir, f'grand_chamber_split_{split_id}.pkl')

        chamber_df = pd.read_pickle(chamber_path)
        grand_df = pd.read_pickle(grand_path)

        chamber_emb = torch.stack([torch.from_numpy(x) if isinstance(x, np.ndarray) else torch.tensor(x) for x in chamber_df['embedding']])
        grand_emb = torch.stack([torch.from_numpy(x) if isinstance(x, np.ndarray) else torch.tensor(x) for x in grand_df['embedding']])

        chamber_norm = torch.nn.functional.normalize(chamber_emb, dim=1)
        grand_norm = torch.nn.functional.normalize(grand_emb, dim=1)

        sim_matrix = torch.matmul(chamber_norm, grand_norm.T)
        nn_indices = torch.argmax(sim_matrix, dim=1).tolist()

        imputed_labels = [grand_df.iloc[i][label_col] for i in nn_indices]
        chamber_df[new_col] = imputed_labels

        out_path = f'datasets/nearest_neighbor/chamber_split_{split_id}.pkl'
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        chamber_df.to_pickle(out_path)


In [5]:
impute_nn_labels_and_store(split_ids=range(7))

Imputing NN labels: 100%|████████████████████████████████████████████████████████████████| 6/6 [00:00<00:00,  8.98it/s]


## Extract the expert labels (dissenting opinions)

In [99]:
from openai import OpenAI
from personal_key import my_key # Requires OpenAI key

client = OpenAI(api_key=my_key)

def make_gpt_call(prompt, model="gpt-4.1-mini"):
    return client.chat.completions.create(
      model=model,
      messages=[{"role": "user", "content": prompt}]
    ).choices[0].message.content

def get_votes(judgment_info, label):
    if not isinstance(judgment_info, str): return None
    prompt = '''
    Read the following ending of a court case file and determine how many judges voted for the violation of article 6. Only output the number of judges who voted for. We only care about article 6, ignore the other articles. If there was a unanimous vote for the violation of article 6, that means 7 judges voted for, since there are 7 judges in the panel. Conversely, if there was a unanimous vote against the violation of article 6, that means 0 judges voted for. 

    Final decision: {label} (0=no violation, 1=violation)
    Text: 
    {text}

    How many judges voted for the violation of article 6? Output a single number only. 
    '''.format(label=label, text=judgment_info)
    return make_gpt_call(prompt)

In [113]:
votes = []
for row_id, row in tqdm(chamber_df_train.iterrows()):
    judgment_info = row['judgment_info']
    label = row['violation']
    votes.append(get_votes(judgment_info, label))
chamber_df_train['votes_for'] = votes
balanced_sets_chamber = generate_balanced_subsets(chamber_df_train, n=7, random_seed=42)
for split_id, c_df in enumerate(balanced_sets_chamber):
    c_df.to_csv(f'datasets/votes/chamber_split_{split_id}.csv', index=False)

5031it [36:42,  2.28it/s]
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  chamber_df_train['votes_for'] = votes


#### Investigate token lengths

In [42]:
from transformers import LongformerTokenizer
import numpy as np

tokenizer = LongformerTokenizer.from_pretrained('allenai/longformer-base-4096')

df = chamber_df_train

token_lengths = [len(tokenizer.encode(text, add_special_tokens=True)) for text in df['text']]

max_length = max(token_lengths)
mean_length = np.mean(token_lengths)
median_length = np.median(token_lengths)

print(f"Max token length: {max_length}")
print(f"Mean token length: {mean_length:.1f}")
print(f"Median token length: {median_length}")

thresholds = [512, 1024, 2048, 4096]
for t in thresholds:
    pct = sum(l <= t for l in token_lengths) / len(token_lengths) * 100
    print(f"Percentage of samples with ≤ {t} tokens: {pct:.1f}%")


In [19]:
def head_tail_truncate_ids(text, tokenizer, max_tokens=4094, head_ratio=0.5):
    text = clean_text(text)
    tokens = tokenizer.encode(text, add_special_tokens=False)
    
    if len(tokens) <= max_tokens:
        return tokens

    head_len = int(max_tokens * head_ratio)
    tail_len = max_tokens - head_len
    return tokens[:head_len] + tokens[-tail_len:]

def print_stats(name, texts):
    raw_lengths = [len(tokenizer.encode(t, add_special_tokens=True)) for t in texts]
    truncated_ids = [head_tail_truncate_ids(t, tokenizer) for t in texts]
    truncated_lengths = [len(ids) + 2 for ids in truncated_ids]  # +2 for special tokens

    print(f"===== {name} =====")
    print(f"Total cases: {len(texts)}")

    def summarize(lengths, label):
        print(f"\n-- {label} --")
        print(f"Max token length: {max(lengths)}")
        print(f"Mean token length: {np.mean(lengths):.1f}")
        print(f"Median token length: {np.median(lengths):.1f}")
        for t in [512, 1024, 2048, 4096]:
            pct = sum(l <= t for l in lengths) / len(lengths) * 100
            print(f"≤ {t} tokens: {pct:.1f}%")

    summarize(raw_lengths, "Raw")
    summarize(truncated_lengths, "Preprocessed (Cleaned + Head-Tail)")
    print()

tokenizer = LongformerTokenizer.from_pretrained('allenai/longformer-base-4096')

print_stats("Chamber", chamber_df_train['text'])
print_stats("Grand Chamber", grand_chamber_df_train['text'])
print_stats("Both", pd.concat([chamber_df_train['text'], grand_chamber_df_train['text']]))

test
===== Chamber =====
Total cases: 5031

-- Raw --
Max token length: 75456
Mean token length: 2317.9
Median token length: 1303.0
≤ 512 tokens: 20.8%
≤ 1024 tokens: 41.6%
≤ 2048 tokens: 65.4%
≤ 4096 tokens: 85.3%

-- Preprocessed (Cleaned + Head-Tail) --
Max token length: 4096
Mean token length: 1756.0
Median token length: 1303.0
≤ 512 tokens: 20.8%
≤ 1024 tokens: 41.6%
≤ 2048 tokens: 65.4%
≤ 4096 tokens: 100.0%

test
===== Grand Chamber =====
Total cases: 157

-- Raw --
Max token length: 22029
Mean token length: 5638.6
Median token length: 4466.0
≤ 512 tokens: 2.5%
≤ 1024 tokens: 7.6%
≤ 2048 tokens: 17.2%
≤ 4096 tokens: 45.9%

-- Preprocessed (Cleaned + Head-Tail) --
Max token length: 4096
Mean token length: 3275.6
Median token length: 4096.0
≤ 512 tokens: 2.5%
≤ 1024 tokens: 7.6%
≤ 2048 tokens: 17.2%
≤ 4096 tokens: 100.0%

