# Hybrid Classification Pipeline

In [15]:
import pandas as pd
import re
import nltk
import os
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from nltk.tokenize import word_tokenize, sent_tokenize
from nltk.corpus import wordnet, stopwords
from nltk.stem import WordNetLemmatizer
from nltk import pos_tag
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score
from sklearn.preprocessing import LabelEncoder, StandardScaler
from transformers import RobertaTokenizer, RobertaModel, Trainer, TrainingArguments, EarlyStoppingCallback
from torch.utils.data import Dataset
import warnings
warnings.filterwarnings('ignore')
os.environ["WANDB_DISABLED"] = "true"

# Keep NLTK resources in a project-local folder and register that folder on
# NLTK's resource search path.
nltk_data_dir = "./nltk_data"
os.makedirs(nltk_data_dir, exist_ok=True)
nltk.data.path.append(nltk_data_dir)

# Resources are requested one by one in a fixed order; quiet=True is passed to
# every download call.
for resource_name in (
    'punkt',
    'punkt_tab',
    'wordnet',
    'omw-1.4',
    'averaged_perceptron_tagger_eng',
    'stopwords',
):
    nltk.download(resource_name, download_dir=nltk_data_dir, quiet=True)

# Pick the compute backend in priority order: Apple MPS, then CUDA, then CPU.
# Each entry is (torch device name, message to print, availability probe); the
# probes are only called until the first one answers True.
_backend_choices = (
    ("mps", "Using Apple Silicon GPU (MPS)", torch.backends.mps.is_available),
    ("cuda", "Using NVIDIA GPU (CUDA)", torch.cuda.is_available),
    ("cpu", "Using CPU", lambda: True),
)
_device_name, _device_banner = next(
    (name, banner) for name, banner, is_usable in _backend_choices if is_usable()
)
device = torch.device(_device_name)
print(_device_banner)

print(f"Device: {device}")

Using Apple Silicon GPU (MPS)
Device: mps


In [16]:
# Read the trial table and discard rows in which every column is empty.
input_file = "../dataset/OBC_Cleaned.csv"
df = pd.read_csv(input_file).dropna(how="all")

print(f"Loaded {len(df)} rows")
print(f"\nColumns: {list(df.columns)}")
# Value counts for the target column and the main offence column.
for column_title in ("Verdict", "Offence"):
    print(f"\n{column_title} distribution:\n{df[column_title].value_counts()}")

Loaded 43389 rows

Columns: ['Trial_ID', 'Date', 'Defendant_Gender', 'Num_Defendants', 'Victim_Gender', 'Num_Victims', 'Offence', 'Offence_Subcategory', 'Verdict', 'Text_Length', 'Year', 'Trial_Text']

Verdict distribution:
Verdict
guilty       31253
notGuilty    12136
Name: count, dtype: int64

Offence distribution:
Offence
theft            35684
violentTheft      2410
deception         1436
breakingPeace     1204
sexual            1069
kill              1050
royalOffences      234
miscellaneous      152
damage             150
Name: count, dtype: int64


In [17]:
# Shared NLP resources used by the cleaning helpers: a WordNet lemmatizer and
# the English stop-word list held as a set for fast membership tests.
lemmatizer = WordNetLemmatizer()
stop_words = {*stopwords.words('english')}

def get_wordnet_pos(tag):
    """Map a POS tag (as returned by ``pos_tag``) to a WordNet POS constant.

    Only the first character of ``tag`` is inspected: ``J`` -> adjective,
    ``V`` -> verb, ``N`` -> noun, ``R`` -> adverb. Any other tag, including an
    empty string, falls back to noun.
    """
    prefix_to_pos = {
        'J': wordnet.ADJ,
        'V': wordnet.VERB,
        'N': wordnet.NOUN,
        'R': wordnet.ADV,
    }
    return prefix_to_pos.get(tag[:1], wordnet.NOUN)

def clean_text_advanced(text):
    """Normalise one transcript into a lower-cased, lemmatised token string.

    Steps, in order: drop a leading "T." / "M." marker, delete every character
    that is not an ASCII letter or whitespace, delete words of one or two
    characters, then sentence-split, tokenise, remove English stop words,
    POS-tag and lemmatise. Sentences and tokens are re-joined with single
    spaces.

    Returns "" for non-string input. If any step raises, the error is printed
    and ``text`` is returned in whatever state it had reached (raw or
    regex-cleaned).
    """
    try:
        if not isinstance(text, str):
            return ""

        # The three regex passes run in this exact order; each deletes its matches.
        for pattern in (r'^\s*[TM]\.\s*', r'[^a-zA-Z\s]', r'\b\w{1,2}\b'):
            text = re.sub(pattern, '', text)

        # NOTE(review): punctuation has already been removed at this point, so
        # sent_tokenize receives text without sentence-ending marks.
        def _normalise(sentence):
            kept = [tok for tok in word_tokenize(sentence.lower()) if tok not in stop_words]
            return " ".join(
                lemmatizer.lemmatize(tok, get_wordnet_pos(pos))
                for tok, pos in pos_tag(kept)
            )

        return " ".join(_normalise(sentence) for sentence in sent_tokenize(text))
    except Exception as e:
        print(f"[ERROR] {e}")
        return text

def clean_text_basic(text):
    """Apply light regex clean-up and whitespace normalisation to ``text``.

    The input is coerced with ``str()``. Then a leading run of non-word
    characters, digits and whitespace is removed, every ``+``, ``*``, ``F`` and
    ``O`` character is deleted, whitespace runs collapse to a single space, and
    the result is stripped.
    """
    # NOTE(review): the class [+*FO] also deletes the capital letters F and O
    # wherever they occur in the text — confirm that this is intended.
    substitutions = (
        (r'^[\W\d\s]+', ''),
        (r'[+*FO]+', ''),
        (r'\s+', ' '),
    )
    cleaned = str(text)
    for pattern, replacement in substitutions:
        cleaned = re.sub(pattern, replacement, cleaned)
    return cleaned.strip()

# Run both cleaning passes over every transcript: the NLP pass first, then the
# light regex pass. Because of astype(str), non-string cells reach the cleaners
# as their string form.
df['Trial_Text'] = (
    df['Trial_Text']
    .astype(str)
    .apply(clean_text_advanced)
    .apply(clean_text_basic)
)

In [18]:
# Keep rows with a usable verdict and a known offence, then drop repeated
# trials (by ID first, then by cleaned text), keeping the first occurrence.
_usable_rows = df['Verdict'].isin(['guilty', 'notGuilty']) & (df['Offence'] != 'Unknown')
df = (
    df[_usable_rows]
    .copy()
    .drop_duplicates(subset="Trial_ID", keep="first")
    .drop_duplicates(subset="Trial_Text", keep="first")
)

# Balance by down-sampling: both verdict groups are sampled to the size of the
# smaller one with a fixed seed.
df_not_guilty, df_guilty = (
    df[df['Verdict'] == verdict] for verdict in ('notGuilty', 'guilty')
)
min_count = min(map(len, (df_not_guilty, df_guilty)))

print(f"Before balancing: notGuilty={len(df_not_guilty)}, guilty={len(df_guilty)}")

df_not_guilty_balanced, df_guilty_balanced = (
    group.sample(n=min_count, random_state=42) for group in (df_not_guilty, df_guilty)
)
df = pd.concat([df_not_guilty_balanced, df_guilty_balanced]).reset_index(drop=True)

print(f"After balancing: {len(df)} total rows")
print(f"\n{df['Verdict'].value_counts()}")

Before balancing: notGuilty=12136, guilty=31252
After balancing: 24272 total rows

Verdict
notGuilty    12136
guilty       12136
Name: count, dtype: int64


In [19]:
# Binary target: guilty -> 1, notGuilty -> 0.
label_map = {"guilty": 1, "notGuilty": 0}
df["Label"] = df["Verdict"].map(label_map)

# One LabelEncoder per categorical column; each writes an integer-coded
# "<column>_Encoded" column.
offence_encoder = LabelEncoder()
offence_sub_encoder = LabelEncoder()
def_gender_encoder = LabelEncoder()
vic_gender_encoder = LabelEncoder()
categorical_encoders = {
    "Offence": offence_encoder,
    "Offence_Subcategory": offence_sub_encoder,
    "Defendant_Gender": def_gender_encoder,
    "Victim_Gender": vic_gender_encoder,
}
for column_name, column_encoder in categorical_encoders.items():
    df[f"{column_name}_Encoded"] = column_encoder.fit_transform(df[column_name])

# Standardise the four numeric columns into "<column>_Scaled" columns.
# NOTE(review): the scaler is fitted on every row here, i.e. before the
# train/validation split that happens later in the notebook.
numeric_columns = ["Year", "Num_Defendants", "Num_Victims", "Text_Length"]
scaler = StandardScaler()
df[[f"{column_name}_Scaled" for column_name in numeric_columns]] = scaler.fit_transform(
    df[numeric_columns]
)

summary_rule = "="*60
print(summary_rule)
print("METADATA FEATURES SUMMARY")
print(summary_rule)
print(f"\nCategorical (one-hot encoded):")
for column_name, column_encoder in categorical_encoders.items():
    print(f"  - {column_name}: {len(column_encoder.classes_)} categories")

print(f"\nNumerical (scaled):")
for column_name in numeric_columns:
    print(f"  - {column_name}: [{df[column_name].min()}, {df[column_name].max()}]")

# Metadata vector width: 4 numeric values plus one slot per category class.
class_counts = [len(column_encoder.classes_) for column_encoder in categorical_encoders.values()]
total_metadata = 4 + sum(class_counts)
print(f"\nTOTAL METADATA FEATURES: {total_metadata}")
print("   (4 numerical + {} + {} + {} + {} categorical)".format(*class_counts))

METADATA FEATURES SUMMARY

Categorical (one-hot encoded):
  - Offence: 9 categories
  - Offence_Subcategory: 49 categories
  - Defendant_Gender: 3 categories
  - Victim_Gender: 3 categories

Numerical (scaled):
  - Year: [1720, 1913]
  - Num_Defendants: [1, 35]
  - Num_Victims: [1, 15]
  - Text_Length: [122, 321242]

TOTAL METADATA FEATURES: 68
   (4 numerical + 9 + 49 + 3 + 3 categorical)


In [20]:
class HybridCourtDataset(Dataset):
    """Torch dataset pairing tokenised trial text with a metadata feature vector.

    Each item is a dict with ``input_ids``, ``attention_mask``, ``labels`` and
    ``metadata``. ``metadata`` concatenates four numeric values with one-hot
    vectors for offence, offence subcategory, defendant gender and victim
    gender, in that order.

    Args:
        texts: Indexable collection of strings passed to the tokenizer.
        labels: Indexable collection of integer class labels.
        metadata_dict: Mapping holding the per-sample lists and class counts
            named in ``_NUMERIC_KEYS`` and ``_ONE_HOT_KEYS``.
        tokenizer: Callable invoked with ``truncation``, ``padding``,
            ``max_length`` and ``return_tensors`` keyword arguments.
        max_len: Length every sequence is padded or truncated to.
    """

    # Per-sample numeric entries, in the order they appear in the vector.
    _NUMERIC_KEYS = (
        'year_scaled',
        'num_defendants_scaled',
        'num_victims_scaled',
        'text_length_scaled',
    )
    # (vector-size key, per-sample class-index key) for each one-hot segment.
    _ONE_HOT_KEYS = (
        ('num_offence_classes', 'offences_encoded'),
        ('num_offence_sub_classes', 'offences_sub_encoded'),
        ('num_def_gender_classes', 'def_gender_encoded'),
        ('num_vic_gender_classes', 'vic_gender_encoded'),
    )

    def __init__(self, texts, labels, metadata_dict, tokenizer, max_len=512):
        self.texts = texts
        self.labels = labels
        self.metadata_dict = metadata_dict
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.labels)

    def _one_hot(self, size_key, index_key, idx):
        """Build the float one-hot vector for sample ``idx`` of one category."""
        vector = torch.zeros(self.metadata_dict[size_key], dtype=torch.float)
        vector[self.metadata_dict[index_key][idx]] = 1.0
        return vector

    def __getitem__(self, idx):
        encoding = self.tokenizer(
            self.texts[idx],
            truncation=True,
            padding='max_length',
            max_length=self.max_len,
            return_tensors='pt'
        )

        # The tokenizer output carries a leading batch axis of size 1; drop it.
        item = {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'labels': torch.tensor(self.labels[idx], dtype=torch.long)
        }

        segments = [
            torch.tensor(
                [self.metadata_dict[key][idx] for key in self._NUMERIC_KEYS],
                dtype=torch.float,
            )
        ]
        segments.extend(
            self._one_hot(size_key, index_key, idx)
            for size_key, index_key in self._ONE_HOT_KEYS
        )

        item["metadata"] = torch.cat(segments)
        return item

In [21]:
class HybridRobertaClassifier(nn.Module):
    """RoBERTa text encoder fused with a small MLP over metadata features.

    The first-token hidden state of ``roberta-base`` is concatenated with a
    32-dimensional metadata embedding, passed through a two-layer fusion MLP
    and projected to ``num_labels`` logits.

    Args:
        num_metadata_features: Width of the metadata vector fed to ``forward``.
        num_labels: Number of output classes.
        dropout: Dropout probability used in the metadata and fusion MLPs.
    """

    def __init__(self, num_metadata_features, num_labels=2, dropout=0.3):
        super().__init__()

        self.roberta = RobertaModel.from_pretrained("roberta-base")
        self.text_hidden_size = self.roberta.config.hidden_size

        # Metadata branch: num_metadata_features -> 64 -> 32.
        metadata_layers = [
            nn.Linear(num_metadata_features, 64),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(64, 32),
            nn.ReLU(),
        ]
        self.metadata_encoder = nn.Sequential(*metadata_layers)

        # Fusion branch: (text hidden size + 32) -> 256 -> 64.
        fusion_layers = [
            nn.Linear(self.text_hidden_size + 32, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Dropout(dropout),
        ]
        self.fusion = nn.Sequential(*fusion_layers)

        self.classifier = nn.Linear(64, num_labels)

    def forward(self, input_ids, attention_mask, metadata, labels=None):
        """Return ``{"loss": ..., "logits": ...}`` for one batch.

        ``loss`` is the cross-entropy between the logits and ``labels``, or
        ``None`` when ``labels`` is not given.
        """
        hidden_states = self.roberta(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        # Position 0 of the sequence axis is used as the text representation.
        text_vector = hidden_states[:, 0, :]
        metadata_vector = self.metadata_encoder(metadata)

        logits = self.classifier(
            self.fusion(torch.cat([text_vector, metadata_vector], dim=1))
        )

        loss = None if labels is None else nn.functional.cross_entropy(logits, labels)
        return {"loss": loss, "logits": logits}

In [22]:
def compute_metrics(eval_pred):
    """Compute accuracy and weighted F1 from a ``(logits, labels)`` pair.

    The predicted class of each row is the argmax of the logits along axis 1.
    Returns a dict with the keys ``"accuracy"`` and ``"f1"``.
    """
    raw_scores, gold = eval_pred
    predicted = np.argmax(raw_scores, axis=1)
    return {
        "accuracy": accuracy_score(gold, predicted),
        "f1": f1_score(gold, predicted, average='weighted'),
    }

In [23]:
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")

texts = df["Trial_Text"].tolist()
labels = df["Label"].tolist()

# Per-sample lists, keyed by the names HybridCourtDataset looks up.
_per_sample_columns = {
    'year_scaled': "Year_Scaled",
    'num_defendants_scaled': "Num_Defendants_Scaled",
    'num_victims_scaled': "Num_Victims_Scaled",
    'text_length_scaled': "Text_Length_Scaled",
    'offences_encoded': "Offence_Encoded",
    'offences_sub_encoded': "Offence_Subcategory_Encoded",
    'def_gender_encoded': "Defendant_Gender_Encoded",
    'vic_gender_encoded': "Victim_Gender_Encoded",
}
# Class counts give the width of each one-hot segment of the metadata vector.
_class_counts = {
    'num_offence_classes': len(offence_encoder.classes_),
    'num_offence_sub_classes': len(offence_sub_encoder.classes_),
    'num_def_gender_classes': len(def_gender_encoder.classes_),
    'num_vic_gender_classes': len(vic_gender_encoder.classes_),
}

metadata_dict = {key: df[column].tolist() for key, column in _per_sample_columns.items()}
metadata_dict.update(_class_counts)

# 4 numeric values plus one slot per class of every categorical feature.
num_metadata_features = 4 + sum(_class_counts.values())

print(f"Total samples: {len(texts)}")
print(f"Total metadata features: {num_metadata_features}")
print(f"  - 4 numerical (year, num_defendants, num_victims, text_length)")
for _count, _noun in zip(
    _class_counts.values(),
    ("offence categories", "offence subcategories", "defendant genders", "victim genders"),
):
    print(f"  - {_count} {_noun}")

Total samples: 24272
Total metadata features: 68
  - 4 numerical (year, num_defendants, num_victims, text_length)
  - 9 offence categories
  - 49 offence subcategories
  - 3 defendant genders
  - 3 victim genders


In [24]:
# Print one sample exactly as the hybrid model receives it: the text, the
# metadata values and the label.
_wide_rule = "="*80
print(_wide_rule)
print("HYBRID MODEL DATA SAMPLE")
print(_wide_rule)

idx = 0
sample_text, sample_label = texts[idx], labels[idx]

print(f"\nTEXT (goes to RoBERTa):")
print(f"   {sample_text[:200]}...")
print(f"   [Length: {len(sample_text)} characters]")

print(f"\nMETADATA FEATURES:")
print(f"   Total: {num_metadata_features} features")
print()

print(f"   Numerical (4):")
for _caption, _key in (
    ("Year", 'year_scaled'),
    ("Num Defendants", 'num_defendants_scaled'),
    ("Num Victims", 'num_victims_scaled'),
    ("Text Length", 'text_length_scaled'),
):
    print(f"     {_caption} (scaled): {metadata_dict[_key][idx]:.4f}")
print()

# Categorical values are shown as the original class name, recovered from the
# encoder through the stored integer code.
print(f"   Categorical (one-hot encoded):")
for _caption, _fitted_encoder, _key in (
    ("Offence", offence_encoder, 'offences_encoded'),
    ("Offence Subcategory", offence_sub_encoder, 'offences_sub_encoded'),
    ("Defendant Gender", def_gender_encoder, 'def_gender_encoded'),
    ("Victim Gender", vic_gender_encoder, 'vic_gender_encoded'),
):
    print(f"     {_caption}: {_fitted_encoder.classes_[metadata_dict[_key][idx]]}")

_verdict_name = 'guilty' if sample_label == 1 else 'notGuilty'
print(f"\nLABEL: {sample_label} ({_verdict_name})")

print("\n" + _wide_rule)
print("vs STANDARD MODEL: embeds Year in text like 'Year: 1720 Text: ...'")
print(_wide_rule)

HYBRID MODEL DATA SAMPLE

TEXT (goes to RoBERTa):
   jhn ellit indict steal live cock value three live hen value property john dunn april evidence affect prisoner guilty try irst middlesex jury baron htham...
   [Length: 152 characters]

METADATA FEATURES:
   Total: 68 features

   Numerical (4):
     Year (scaled): -0.4266
     Num Defendants (scaled): -0.3864
     Num Victims (scaled): -0.3165
     Text Length (scaled): -0.4009

   Categorical (one-hot encoded):
     Offence: theft
     Offence Subcategory: animalTheft
     Defendant Gender: male
     Victim Gender: male

LABEL: 0 (notGuilty)

vs STANDARD MODEL: embeds Year in text like 'Year: 1720 Text: ...'


In [25]:
from sklearn.model_selection import train_test_split

# Stratified 80/20 split over row positions, so texts, labels and metadata can
# all be sliced with the same index lists.
train_idx, val_idx = train_test_split(
    range(len(texts)), test_size=0.2, random_state=42, stratify=labels
)


def _take(sequence, positions):
    """Return the elements of ``sequence`` at ``positions``, in that order."""
    return [sequence[i] for i in positions]


def _slice_metadata(metadata, positions):
    """Subset every list-valued entry of ``metadata``; copy other values as-is."""
    return {
        key: _take(value, positions) if isinstance(value, list) else value
        for key, value in metadata.items()
    }


train_texts, val_texts = _take(texts, train_idx), _take(texts, val_idx)
train_labels, val_labels = _take(labels, train_idx), _take(labels, val_idx)

train_metadata = _slice_metadata(metadata_dict, train_idx)
val_metadata = _slice_metadata(metadata_dict, val_idx)

print(f"Training samples: {len(train_texts)}")
print(f"Validation samples: {len(val_texts)}")

Training samples: 19417
Validation samples: 4855


In [26]:
# Wrap each split in the hybrid dataset (default max_len).
train_dataset, val_dataset = (
    HybridCourtDataset(split_texts, split_labels, split_metadata, tokenizer)
    for split_texts, split_labels, split_metadata in (
        (train_texts, train_labels, train_metadata),
        (val_texts, val_labels, val_metadata),
    )
)

model = HybridRobertaClassifier(num_metadata_features=num_metadata_features)

# Hyper-parameters in one place. Evaluation, checkpointing and logging all
# happen once per epoch; the checkpoint with the highest validation accuracy
# is reloaded when training ends.
_training_settings = dict(
    output_dir="./temp_hybrid_model",
    num_train_epochs=10,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=16,
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="epoch",
    learning_rate=2e-5,
    report_to="none",
    weight_decay=0.01,
    lr_scheduler_type="linear",
    warmup_steps=500,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    greater_is_better=True,
    use_cpu=False,
)
training_args = TrainingArguments(**_training_settings)

# Early stopping is configured with a patience of 3.
_early_stopping = EarlyStoppingCallback(early_stopping_patience=3)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
    callbacks=[_early_stopping],
)

print("Starting training...")
trainer.train()

print("\nEvaluating model...")
metrics = trainer.evaluate()

def _first_value_per_epoch(log_records, metric_key):
    """Return ``(epochs, values)`` of ``metric_key`` from Trainer log records.

    Only the first record carrying ``metric_key`` is kept for each epoch value.
    ``trainer.evaluate()`` called after training appends one more evaluation
    record for the epoch training stopped at; keeping the first record per
    epoch stops that extra record from showing up as a spurious additional
    point on the validation curves.

    Args:
        log_records: Iterable of dicts such as ``trainer.state.log_history``.
        metric_key: Name of the metric to extract (e.g. ``"eval_loss"``).

    Returns:
        Two parallel lists: the epoch of each kept record and its metric value.
        A record without an ``"epoch"`` entry is numbered sequentially.
    """
    epochs, values = [], []
    for record in log_records:
        if metric_key not in record:
            continue
        epoch = record.get("epoch", len(epochs) + 1)
        if epoch in epochs:
            continue
        epochs.append(epoch)
        values.append(record[metric_key])
    return epochs, values


logs = trainer.state.log_history
train_epochs, train_loss = _first_value_per_epoch(logs, "loss")
val_epochs, val_loss = _first_value_per_epoch(logs, "eval_loss")
val_acc_epochs, val_acc = _first_value_per_epoch(logs, "eval_accuracy")

# Curves are drawn against the logged epoch so the "Epoch" axis is accurate;
# markers keep a curve visible when only one or two epochs were recorded.
fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(10, 4))

loss_ax.plot(train_epochs, train_loss, marker="o", label="Train Loss")
loss_ax.plot(val_epochs, val_loss, marker="o", label="Val Loss")
loss_ax.set_title("Training and Validation Loss")
loss_ax.set_xlabel("Epoch")
loss_ax.set_ylabel("Loss")
loss_ax.legend()

acc_ax.plot(val_acc_epochs, val_acc, marker="o", label="Val Accuracy")
acc_ax.set_title("Validation Accuracy")
acc_ax.set_xlabel("Epoch")
acc_ax.set_ylabel("Accuracy")
acc_ax.legend()

fig.tight_layout()
plt.show()

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Starting training...


Epoch,Training Loss,Validation Loss,Accuracy,F1
1,0.6629,0.396661,0.877034,0.876379
2,0.3063,0.208231,0.91102,0.910922


KeyboardInterrupt: 

In [None]:
# Report the metrics returned by the final evaluation call.
_results_rule = "="*60
print("\n" + _results_rule)
print("HYBRID MODEL - TRAINING RESULTS")
print(_results_rule)
print(f"\nValidation Accuracy: {metrics['eval_accuracy']:.4f}")
for _caption, _metric_key in (("F1 Score", 'eval_f1'), ("Loss", 'eval_loss')):
    print(f"Validation {_caption}: {metrics[_metric_key]:.4f}")

In [None]:
import pickle

# Persist everything needed to reload the model and rebuild its inputs:
# the model weights, the tokenizer files and the fitted preprocessing objects.
final_model_path = "./hybrid_roberta_model"
os.makedirs(final_model_path, exist_ok=True)

torch.save(model.state_dict(), os.path.join(final_model_path, "model_state.pt"))

tokenizer.save_pretrained(final_model_path)

# All four label encoders are saved, since each one defines a one-hot segment
# of the metadata vector. The fitted scaler in this notebook is named `scaler`
# (it standardises Year, Num_Defendants, Num_Victims and Text_Length); the
# original "year_scaler.pkl" file name is kept so existing loaders still find it.
preprocessing_artifacts = {
    "offence_encoder.pkl": offence_encoder,
    "offence_sub_encoder.pkl": offence_sub_encoder,
    "def_gender_encoder.pkl": def_gender_encoder,
    "vic_gender_encoder.pkl": vic_gender_encoder,
    "year_scaler.pkl": scaler,
}
for artifact_name, fitted_object in preprocessing_artifacts.items():
    with open(os.path.join(final_model_path, artifact_name), "wb") as f:
        pickle.dump(fitted_object, f)

print(f"\nModel saved to {final_model_path}")
print(f"  - model_state.pt")
print(f"  - tokenizer files")
for artifact_name in preprocessing_artifacts:
    print(f"  - {artifact_name}")