# Japanese Legal Textual Entailment Solution for COLIEE 2025

This notebook implements a complete solution for the Legal Textual Entailment task from COLIEE 2025 using the **original Japanese dataset** with a multilingual model (XLM-RoBERTa).


## 1. Import Libraries

In [None]:
import os
import re
import xml.etree.ElementTree as ET
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from sklearn.metrics.pairwise import cosine_similarity
import torch
from torch.utils.data import Dataset, DataLoader
import warnings
warnings.filterwarnings('ignore')

# Import transformers
import transformers
print(f"Using transformers version: {transformers.__version__}")
from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification

# Select the compute device: prefer a CUDA GPU when one is available, otherwise use the CPU.
# `device` is a notebook-level global; later cells move models and tokenized inputs onto it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using transformers version: 4.48.3
Using device: cuda


## 2. Upload and Extract Dataset

In [None]:
import subprocess
import zipfile

from google.colab import files

uploaded = files.upload()  # Upload the dataset archive (.zip or .rar)

# The run cell in section 5 expects the dataset under /content/coileestatute.
EXTRACT_DIR = "coileestatute"
os.makedirs(EXTRACT_DIR, exist_ok=True)

# Extract whatever was actually uploaded instead of assuming a fixed
# "coileestatute.rar" file name (the previous cell failed with "Cannot open
# coileestatute.rar" when a .zip archive was uploaded).
for archive_name in uploaded:
    suffix = os.path.splitext(archive_name)[1].lower()
    if suffix == ".zip":
        with zipfile.ZipFile(archive_name) as archive:
            archive.extractall(EXTRACT_DIR)
    elif suffix == ".rar":
        # unrar is only needed for RAR archives. As before, the archive is
        # unpacked into the current directory.
        subprocess.run(["apt-get", "install", "-y", "unrar"], check=True)
        subprocess.run(["unrar", "x", "-o+", archive_name], check=True)
    else:
        raise ValueError(f"Unsupported archive type: {archive_name}")

# NOTE(review): the folder layout inside the archive is not known here; check that
# the listing below matches the `base_path` used in section 5 and adjust if needed.
print(f"Contents of {EXTRACT_DIR}/: {sorted(os.listdir(EXTRACT_DIR))}")

Saving COLIEE2025statute_data-Japanese.zip to COLIEE2025statute_data-Japanese.zip
Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
unrar is already the newest version (1:6.1.5-1).
0 upgraded, 0 newly installed, 0 to remove and 29 not upgraded.

UNRAR 6.11 beta 1 freeware      Copyright (c) 1993-2022 Alexander Roshal

Cannot open coileestatute.rar
No such file or directory
No files to extract


## 3. Define Dataset Class for Japanese Text

In [None]:
class JapaneseLegalEntailmentDataset(Dataset):
    """
    Torch dataset that turns (article, query, label) rows into model-ready tensors.

    Each item is the tokenizer encoding of the pair (article_text, query), padded or
    truncated to 512 tokens, plus a ``labels`` tensor (1 for label 'Y', 0 otherwise).
    """

    def __init__(self, df, tokenizer, civil_code=None):
        """
        Keep references to the inputs; tokenization happens lazily per item.

        Args:
            df: DataFrame with 'query', 'article_text' and 'label' columns
            tokenizer: Tokenizer used to encode each (article, query) pair
            civil_code: Dictionary of civil code articles (stored, not read by this class)
        """
        self.civil_code = civil_code
        self.tokenizer = tokenizer
        self.df = df

    def __len__(self):
        """Number of rows in the underlying DataFrame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Encode row ``idx`` and return a dict of tensors including 'labels'."""
        example = self.df.iloc[idx]
        question = example['query']
        article = example['article_text']
        target = int(example['label'] == 'Y')

        # Japanese text is passed through untouched: the subword tokenizer
        # does not need whitespace-separated words.
        batch_of_one = self.tokenizer(
            article,
            question,
            padding="max_length",
            truncation=True,
            max_length=512,
            return_tensors="pt",
        )

        # return_tensors="pt" adds a leading batch axis of size 1; drop it so the
        # DataLoader can stack items itself.
        item = {}
        for key, tensor in batch_of_one.items():
            item[key] = tensor.squeeze(0)
        item['labels'] = torch.tensor(target)

        return item

## 4. Define Main System Class for Japanese Text

In [None]:
class JapaneseLegalEntailmentSystem:
    """
    Main class for the Japanese Legal Textual Entailment System.
    """

    def __init__(self, base_path, model_name="xlm-roberta-base"):
        """
        Initialize the Japanese Legal Entailment System.

        Loads the tokenizer for ``model_name`` immediately; models are created
        later by the embedding and training methods.

        Args:
            base_path: Path to the dataset directory
            model_name: Name of the pre-trained model to use
        """
        self.base_path = base_path
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.retrieval_model = None
        self.entailment_model = None
        self.civil_code = {}
        self.civil_code_embeddings = None
        # Article numbers aligned row-by-row with civil_code_embeddings. Declared here
        # (it is filled by build_retrieval_index) so the attribute always exists.
        self.article_nums = []
        self.threshold = 0.5  # Default threshold, will be optimized during training

    def load_civil_code(self, file_path):
        """
        Load and parse the Japanese civil code articles.

        Args:
            file_path: Path to the civil code text file
        """
        print("Loading Japanese civil code articles...")
        with open(file_path, 'r', encoding='utf-8') as f:
            content = f.read()

        # Extract articles using regex
        # Japanese pattern for "Article X:" (条文 X:)
        article_pattern = r'第(\d+(?:-\d+)?)条:(.+?)(?=第\d+条:|$)'
        articles = re.findall(article_pattern, content, re.DOTALL)

        # Process and store articles
        for article_num, article_text in articles:
            # Clean article text
            article_text = re.sub(r'\n\d+:', ' ', article_text)  # Replace paragraph numbers
            article_text = re.sub(r'\s+', ' ', article_text).strip()  # Normalize whitespace
            self.civil_code[article_num] = article_text

        print(f"Loaded {len(self.civil_code)} Japanese civil code articles.")

    def load_training_data(self, directory):
        """
        Load and parse the Japanese training data from XML files.

        Args:
            directory: Directory containing training XML files

        Returns:
            DataFrame containing training data
        """
        print("Loading Japanese training data...")
        data = []

        # List all XML files in the directory
        xml_files = [f for f in os.listdir(directory) if f.endswith('.xml')]

        for xml_file in xml_files:
            file_path = os.path.join(directory, xml_file)
            tree = ET.parse(file_path)
            root = tree.getroot()

            for pair in root.findall('pair'):
                pair_id = pair.get('id')
                label = pair.get('label')
                t1 = pair.find('t1').text.strip() if pair.find('t1') is not None else ""
                t2 = pair.find('t2').text.strip() if pair.find('t2') is not None else ""

                # Extract article numbers from t1
                article_nums = self._extract_article_numbers(t1)

                # Add engineered features
                features = self._compute_features(t1, t2)

                data.append({
                    'id': pair_id,
                    'article_text': t1,
                    'query': t2,
                    'label': label,
                    'article_nums': article_nums,
                    **features  # Add all computed features
                })

        df = pd.DataFrame(data)
        print(f"Loaded {len(df)} Japanese training examples.")
        return df

    def _compute_features(self, article_text, query):
        """
        Compute engineered features between Japanese article and query.

        Args:
            article_text: Article text
            query: Query text

        Returns:
            Dictionary of features
        """
        # For Japanese, we use character-level features since words aren't separated by spaces
        article_chars = set(article_text)
        query_chars = set(query)

        # Compute character overlap
        if len(query_chars) > 0:
            char_overlap_ratio = len(article_chars.intersection(query_chars)) / len(query_chars)
        else:
            char_overlap_ratio = 0

        # Compute length features
        article_length = len(article_text)
        query_length = len(query)
        length_ratio = query_length / article_length if article_length > 0 else 0

        # Count specific Japanese legal terms (could be expanded)
        legal_terms = ['法律', '条文', '権利', '義務', '契約']
        term_counts = {}

        for term in legal_terms:
            term_counts[f'article_{term}_count'] = article_text.count(term)
            term_counts[f'query_{term}_count'] = query.count(term)

        return {
            'char_overlap_ratio': char_overlap_ratio,
            'article_length': article_length,
            'query_length': query_length,
            'length_ratio': length_ratio,
            **term_counts
        }

    def _extract_article_numbers(self, text):
        """
        Extract article numbers from Japanese text.

        Args:
            text: Text containing article references

        Returns:
            List of article numbers
        """
        if not text:
            return []

        # Pattern to match "Article X:" in Japanese (第X条)
        pattern = r'第(\d+(?:-\d+)?)条'
        matches = re.findall(pattern, text)
        return matches

    def create_embeddings(self, texts, batch_size=8):
        """
        Create embeddings for a list of Japanese texts.

        The encoder is loaded on first use and cached in ``self.retrieval_model``,
        so repeated calls (one per query during retrieval) do not reload the
        pre-trained weights each time.

        Args:
            texts: List of texts to embed
            batch_size: Batch size for processing

        Returns:
            Numpy array with one row per text, taken from the first-token
            (CLS position) hidden state of the encoder's last layer
        """
        model = getattr(self, 'retrieval_model', None)
        if model is None:
            model = AutoModel.from_pretrained(self.model_name).to(device)
            self.retrieval_model = model
        model.eval()  # inference only: make sure dropout is disabled

        # np.vstack cannot stack an empty list, so answer empty input directly.
        if len(texts) == 0:
            return np.empty((0, model.config.hidden_size), dtype=np.float32)

        embeddings = []
        for start in range(0, len(texts), batch_size):
            batch_texts = texts[start:start + batch_size]
            inputs = self.tokenizer(batch_texts, padding=True, truncation=True,
                                    max_length=512, return_tensors="pt").to(device)

            with torch.no_grad():
                outputs = model(**inputs)

            # Use CLS token embedding as the sentence embedding
            batch_embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy()
            embeddings.append(batch_embeddings)

        return np.vstack(embeddings)

    def build_retrieval_index(self):
        """
        Build the retrieval index for Japanese civil code articles.
        """
        print("Building retrieval index for Japanese articles...")
        # Prepare texts for embedding
        article_nums = list(self.civil_code.keys())
        article_texts = [self.civil_code[num] for num in article_nums]

        # Create embeddings
        self.civil_code_embeddings = self.create_embeddings(article_texts)
        self.article_nums = article_nums

        print(f"Created embeddings for {len(article_nums)} Japanese articles.")

    def retrieve_articles(self, query, top_k=5):
        """
        Retrieve relevant Japanese articles for a query.

        Args:
            query: Query text
            top_k: Number of articles to retrieve

        Returns:
            List of (article_num, score) tuples
        """
        # Create query embedding
        query_embedding = self.create_embeddings([query])

        # Calculate similarity scores
        similarity_scores = cosine_similarity(query_embedding, self.civil_code_embeddings)[0]

        # Get top-k articles
        top_indices = np.argsort(-similarity_scores)[:top_k]

        results = []
        for idx in top_indices:
            article_num = self.article_nums[idx]
            score = similarity_scores[idx]
            results.append((article_num, score))

        return results

    def train_entailment_model(self, train_df, val_df=None, epochs=5, batch_size=8):
        """
        Fine-tune a two-class sequence classifier on (article, query) pairs.

        The first 8 encoder layers are frozen and the rest is trained with AdamW.
        When ``val_df`` is given, validation metrics are printed after every epoch
        and the F1-optimal decision threshold found in that epoch is stored in
        ``self.threshold``. The trained model is stored in ``self.entailment_model``
        in evaluation mode.

        Note: only the 'article_text', 'query' and 'label' columns are consumed
        (through JapaneseLegalEntailmentDataset); the engineered feature columns
        are not fed to the model.

        Args:
            train_df: Training data DataFrame
            val_df: Validation data DataFrame
            epochs: Number of training epochs
            batch_size: Training batch size
        """
        print("Training Japanese entailment model...")

        # Class counts; minlength=2 keeps both entries addressable even when one
        # class is absent from the training split.
        y_train = [1 if label == 'Y' else 0 for label in train_df['label']]
        class_counts = np.bincount(y_train, minlength=2)
        n_negative, n_positive = int(class_counts[0]), int(class_counts[1])
        print(f"Class distribution in training data: {n_negative} negative, {n_positive} positive")

        # Prepare datasets
        train_dataset = JapaneseLegalEntailmentDataset(train_df, self.tokenizer, self.civil_code)
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

        if val_df is not None:
            val_dataset = JapaneseLegalEntailmentDataset(val_df, self.tokenizer, self.civil_code)
            val_loader = DataLoader(val_dataset, batch_size=batch_size)

        # Initialize model
        model = AutoModelForSequenceClassification.from_pretrained(
            self.model_name,
            num_labels=2
        ).to(device)

        # Freeze the first 8 encoder layers (indices 0-7). The layer index is read
        # from the fourth dot-separated field of the parameter name.
        print("Freezing first 8 layers of the model...")
        for name, param in model.named_parameters():
            if 'roberta.encoder.layer' in name:
                layer_num = int(name.split('.')[3])
                if layer_num < 8:
                    param.requires_grad = False

        # Count trainable parameters
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        total_params = sum(p.numel() for p in model.parameters())
        print(f"Trainable parameters: {trainable_params:,} ({trainable_params/total_params:.1%} of total)")

        # Set up optimizer with lower learning rate for Japanese; frozen parameters
        # never receive gradients, so only the trainable ones are handed over.
        optimizer = torch.optim.AdamW(
            [p for p in model.parameters() if p.requires_grad], lr=2e-5
        )

        # Use a weighted loss when the classes are imbalanced. The weight is a ratio
        # of the two counts, so it is only meaningful when both classes are present.
        if n_negative == 0 or n_positive == 0:
            loss_fn = torch.nn.CrossEntropyLoss()
            print("Using standard loss (only one class present in training data)")
        elif abs(n_negative - n_positive) > 50:  # Arbitrary threshold for imbalance
            weight = torch.tensor([1.0, n_negative / n_positive], dtype=torch.float).to(device)
            loss_fn = torch.nn.CrossEntropyLoss(weight=weight)
            print(f"Using weighted loss with weights: {weight.cpu().numpy()}")
        else:
            loss_fn = torch.nn.CrossEntropyLoss()
            print("Using standard loss (classes are balanced)")

        # Training loop
        for epoch in range(epochs):
            model.train()
            total_loss = 0

            for batch in train_loader:
                # Move batch to device
                batch = {k: v.to(device) for k, v in batch.items()}
                labels = batch.pop('labels')

                # Forward pass; the loss is computed here (not by the model) so the
                # class weights above are applied.
                outputs = model(**batch)
                logits = outputs.logits
                loss = loss_fn(logits, labels)

                # Backward pass and optimize
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total_loss += loss.item()

            avg_loss = total_loss / len(train_loader)
            print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")

            # Validation
            if val_df is not None:
                model.eval()
                val_preds = []
                val_labels = []
                val_probs = []

                with torch.no_grad():
                    for batch in val_loader:
                        batch = {k: v.to(device) for k, v in batch.items()}
                        labels = batch.pop('labels')

                        outputs = model(**batch)
                        logits = outputs.logits
                        probs = torch.softmax(logits, dim=1)[:, 1]  # Probability of positive class
                        preds = torch.argmax(logits, dim=1)

                        val_preds.extend(preds.cpu().numpy())
                        val_labels.extend(labels.cpu().numpy())
                        val_probs.extend(probs.cpu().numpy())

                # Calculate metrics with default threshold (0.5)
                accuracy = accuracy_score(val_labels, val_preds)
                f1 = f1_score(val_labels, val_preds)
                precision = precision_score(val_labels, val_preds)
                recall = recall_score(val_labels, val_preds)

                print(f"Validation Metrics (threshold=0.5):")
                print(f"  Accuracy: {accuracy:.4f}")
                print(f"  F1 Score: {f1:.4f}")
                print(f"  Precision: {precision:.4f}")
                print(f"  Recall: {recall:.4f}")

                # Find the threshold on the positive-class probability that
                # maximizes F1; falls back to 0.5 when no threshold scores above 0.
                best_f1 = 0
                best_threshold = 0.5
                thresholds = np.arange(0.3, 0.8, 0.05)

                for threshold in thresholds:
                    threshold_preds = [1 if p > threshold else 0 for p in val_probs]
                    threshold_f1 = f1_score(val_labels, threshold_preds)
                    if threshold_f1 > best_f1:
                        best_f1 = threshold_f1
                        best_threshold = threshold

                # Calculate metrics with best threshold
                best_preds = [1 if p > best_threshold else 0 for p in val_probs]
                best_accuracy = accuracy_score(val_labels, best_preds)
                best_precision = precision_score(val_labels, best_preds)
                best_recall = recall_score(val_labels, best_preds)

                print(f"Best threshold: {best_threshold:.2f}")
                print(f"Validation Metrics (threshold={best_threshold:.2f}):")
                print(f"  Accuracy: {best_accuracy:.4f}")
                print(f"  F1 Score: {best_f1:.4f}")
                print(f"  Precision: {best_precision:.4f}")
                print(f"  Recall: {best_recall:.4f}")

                # Save best threshold for prediction (the last epoch's value wins,
                # matching the model state that is kept)
                self.threshold = best_threshold

        # Without a validation set the loop above never leaves train mode; switch
        # to eval so later predictions do not run with dropout enabled.
        model.eval()
        self.entailment_model = model
        print("Japanese entailment model training completed.")

    def predict_entailment(self, query, retrieved_articles):
        """
        Predict entailment for a Japanese query and retrieved articles.

        Args:
            query: Query text
            retrieved_articles: List of (article_num, score) tuples

        Returns:
            Entailment prediction (True/False)
        """
        if not retrieved_articles:
            return False

        # Prepare inputs
        inputs = []
        for article_num, _ in retrieved_articles:
            if article_num in self.civil_code:
                article_text = self.civil_code[article_num]
                inputs.append((article_text, query))

        if not inputs:
            return False

        # Tokenize inputs
        encoded_inputs = []
        for article_text, query_text in inputs:
            encoded = self.tokenizer(
                article_text,
                query_text,
                padding="max_length",
                truncation=True,
                max_length=512,
                return_tensors="pt"
            )
            encoded_inputs.append(encoded)

        # Make predictions
        entailment_scores = []
        for encoded in encoded_inputs:
            encoded = {k: v.to(device) for k, v in encoded.items()}
            with torch.no_grad():
                outputs = self.entailment_model(**encoded)
                logits = outputs.logits
                probabilities = torch.softmax(logits, dim=1)
                entailment_score = probabilities[0, 1].item()  # Probability of entailment (class 1)
                entailment_scores.append(entailment_score)

        # Aggregate scores (using max as a simple strategy)
        max_score = max(entailment_scores)

        # Use optimized threshold
        return max_score > self.threshold

    def process_test_data(self, test_file, output_task3, output_task4, system_id="SYSTEM"):
        """
        Process Japanese test data and generate output files for Task 3 and Task 4.

        Args:
            test_file: Path to test XML file
            output_task3: Path to Task 3 output file
            output_task4: Path to Task 4 output file
            system_id: System identifier for output files
        """
        print(f"Processing Japanese test data from {test_file}...")

        # Parse test file
        tree = ET.parse(test_file)
        root = tree.getroot()

        task3_results = []
        task4_results = []

        for pair in root.findall('pair'):
            pair_id = pair.get('id')
            query = pair.find('t2').text.strip() if pair.find('t2') is not None else ""

            # Task 3: Retrieve relevant articles
            retrieved_articles = self.retrieve_articles(query, top_k=5)

            # Write Task 3 results
            for rank, (article_num, score) in enumerate(retrieved_articles, 1):
                task3_line = f"{pair_id} Q0 {article_num} {rank} {score:.6f} {system_id}"
                task3_results.append(task3_line)

            # Task 4: Predict entailment
            is_entailed = self.predict_entailment(query, retrieved_articles)
            task4_line = f"{pair_id} {'Y' if is_entailed else 'N'} {system_id}"
            task4_results.append(task4_line)

        # Write output files
        with open(output_task3, 'w', encoding='utf-8') as f:
            f.write('\n'.join(task3_results))

        with open(output_task4, 'w', encoding='utf-8') as f:
            f.write('\n'.join(task4_results))

        print(f"Task 3 results written to {output_task3}")
        print(f"Task 4 results written to {output_task4}")

## 5. Run the System on Japanese Data

In [None]:
# Fix the random seeds so the data split, classifier-head initialization and
# batch shuffling are repeatable across "Restart & Run All".
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Set paths for Japanese dataset
base_path = "/content/coileestatute/COLIEE2025statute_data"
civil_code_path = os.path.join(base_path, "text/civil_code_jp-1to724.txt")
train_dir = os.path.join(base_path, "train")

# Initialize system with XLM-RoBERTa (multilingual model)
system = JapaneseLegalEntailmentSystem(base_path, model_name="xlm-roberta-base")

# Load Japanese civil code
system.load_civil_code(civil_code_path)

# Load Japanese training data (kept under its own name so the full set is not
# overwritten by the split below)
all_pairs_df = system.load_training_data(train_dir)

# Split data for training and validation. Stratifying on the label keeps the Y/N
# ratio the same in both parts, which matters because the decision threshold is
# tuned on the validation part.
train_df, val_df = train_test_split(
    all_pairs_df,
    test_size=0.2,
    random_state=SEED,
    stratify=all_pairs_df['label'],
)

# Build retrieval index
system.build_retrieval_index()

# Train entailment model
system.train_entailment_model(train_df, val_df, epochs=5)

## 6. Process Test Data and Generate Results

In [None]:
# Demonstration only: one of the training files stands in for a real test file,
# so the model has already seen these pairs during training.
test_file = os.path.join(train_dir, "riteval_R05.xml")
output_task3, output_task4 = "task3.YOURID", "task4.YOURID"
system.process_test_data(test_file, output_task3, output_task4, system_id="YOURID")

# Offer both result files for download through the Colab browser session
from google.colab import files
for result_path in (output_task3, output_task4):
    files.download(result_path)

## 7. Evaluate Performance

In [None]:
# Score the full pipeline (retrieval followed by entailment) on the validation set.
# Each entry of val_results is a (true_label, predicted_label) pair of 0/1 ints.
val_results = []
for query, gold in zip(val_df['query'], val_df['label']):
    retrieved_articles = system.retrieve_articles(query, top_k=5)
    is_entailed = system.predict_entailment(query, retrieved_articles)
    val_results.append((int(gold == 'Y'), int(bool(is_entailed))))

# Unzip the pairs into parallel label sequences and compute the four metrics
true_labels, pred_labels = zip(*val_results)
accuracy, f1, precision, recall = (
    metric(true_labels, pred_labels)
    for metric in (accuracy_score, f1_score, precision_score, recall_score)
)

print(f"Validation Results:")
for metric_name, metric_value in (
    ("Accuracy", accuracy),
    ("F1 Score", f1),
    ("Precision", precision),
    ("Recall", recall),
):
    print(f"{metric_name}: {metric_value:.4f}")

## 8. Comparison with English Model (Optional)

If you have results from both the English and Japanese models, you can compare them here to see which performs better.

In [None]:
# Example comparison code (replace the English numbers with your actual results)
english_metrics = dict(accuracy=0.4875, f1=0.6555, precision=0.4875, recall=1.0000)

# Japanese numbers come from the validation metrics computed in the previous section
japanese_metrics = dict(accuracy=accuracy, f1=f1, precision=precision, recall=recall)

# Metric key -> display name, in the row order of the table
metric_labels = {
    'accuracy': 'Accuracy',
    'f1': 'F1 Score',
    'precision': 'Precision',
    'recall': 'Recall',
}

# Create comparison table; "Difference" is Japanese minus English
comparison_df = pd.DataFrame({
    'Metric': list(metric_labels.values()),
    'English Model': [english_metrics[key] for key in metric_labels],
    'Japanese Model': [japanese_metrics[key] for key in metric_labels],
    'Difference': [japanese_metrics[key] - english_metrics[key] for key in metric_labels],
})

comparison_df