# 1. Load libraries

In [71]:
import json
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification, Trainer, TrainingArguments
from sklearn.model_selection import train_test_split
from datapreparation import IGTClassificationDataset

labeled_data_path = "labeled_data.json"
model_name = "nlpaueb/legal-bert-base-uncased"
general_output_path = "./results"
model_output_path = "./igtclassification-model"
label_map_fname = "label_map.json"

# 2. Load labeled data for training

In [72]:
def load_data_from_json(json_file_path):
    # Lists to store all examples
    all_tokens = []
    all_labels = []
    
    # Read and parse the JSON file
    with open(json_file_path, 'r') as file:
        data = json.load(file)
    
    # Iterate through each entry in the JSON array
    for entry in data:
        # Extract tokens and labels for this entry
        tokens = entry["tokens"]
        labels = entry["labels"]
        
        # Add to our collected lists
        all_tokens.append(tokens)
        all_labels.append(labels)
    
    return all_tokens, all_labels

json_file_path = labeled_data_path
tokens_list, labels_list = load_data_from_json(json_file_path)

# 3. Split data into train and validation sets

In [73]:
train_tokens, val_tokens, train_labels, val_labels = train_test_split(
    tokens_list, labels_list, test_size=0.3, random_state=42
)

tokenizer = AutoTokenizer.from_pretrained(model_name)
train_dataset = IGTClassificationDataset(train_tokens, train_labels, tokenizer)
val_dataset = IGTClassificationDataset(val_tokens, val_labels, tokenizer)

# 4. Initialize pretrained model (LegalBERT)

In [74]:
# hide warnings
from transformers.utils import logging
logging.set_verbosity_error()

num_labels = len(train_dataset.get_label_map())
model = AutoModelForTokenClassification.from_pretrained(
    model_name,
    num_labels=num_labels
)

# 5. Train and save fine-tuned model

In [75]:
# Define training arguments
training_args = TrainingArguments(
    output_dir=general_output_path,
    eval_strategy="epoch",
    learning_rate=2e-7,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=13,
    weight_decay=0.05,
)

# Initialize Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)

# Train the model
trainer.train()

# Save model and tokenizer in a clean format
trainer.model.save_pretrained(model_output_path)
tokenizer.save_pretrained(model_output_path)

# Save the label mapping for inference
with open(f"{model_output_path}/{label_map_fname}", "w") as f:
    json.dump(train_dataset.get_label_map(), f)

Tokens: 34, Labels: 33
{'eval_loss': 2.6780357360839844, 'eval_runtime': 0.1605, 'eval_samples_per_second': 49.858, 'eval_steps_per_second': 6.232, 'epoch': 1.0}
Tokens: 34, Labels: 33
{'eval_loss': 2.678057909011841, 'eval_runtime': 0.0745, 'eval_samples_per_second': 107.425, 'eval_steps_per_second': 13.428, 'epoch': 2.0}
Tokens: 34, Labels: 33
{'eval_loss': 2.678065299987793, 'eval_runtime': 0.0726, 'eval_samples_per_second': 110.202, 'eval_steps_per_second': 13.775, 'epoch': 3.0}
Tokens: 34, Labels: 33
{'eval_loss': 2.6780707836151123, 'eval_runtime': 0.0726, 'eval_samples_per_second': 110.14, 'eval_steps_per_second': 13.767, 'epoch': 4.0}
Tokens: 34, Labels: 33
{'eval_loss': 2.6780970096588135, 'eval_runtime': 0.0733, 'eval_samples_per_second': 109.097, 'eval_steps_per_second': 13.637, 'epoch': 5.0}
Tokens: 34, Labels: 33
{'eval_loss': 2.6780989170074463, 'eval_runtime': 0.0746, 'eval_samples_per_second': 107.2, 'eval_steps_per_second': 13.4, 'epoch': 6.0}
Tokens: 34, Labels: 33
{'

# 6. Load model from file

In [76]:
def load_model(model_path):
    # Load model and tokenizer
    model = AutoModelForTokenClassification.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    
    # Load label mapping
    with open(f"{model_path}/{label_map_fname}", "r") as f:
        label_map = json.load(f)
        
    # Create reverse mapping (id to label)
    id_to_label = {int(idx): label for label, idx in label_map.items()}
    
    return model, tokenizer, id_to_label

# 7. Predict function

In [77]:
def predict_sequence_tags(text, model, tokenizer, id_to_label):
    """
    Predict sequence tags for input text using a fine-tuned model.
    
    Args:
        text (str): Input text to be tagged
        model: Fine-tuned token classification model
        tokenizer: Tokenizer corresponding to the model
        id_to_label (dict): Mapping from label IDs to label names
        
    Returns:
        list: List of (token, label) tuples
    """
    # Use the tokenizer to split into tokens first
    nlp_tokens = text.split()  # We use simple splitting first to get base tokens
    
    # Then get BERT tokens and word_ids mapping
    encoding = tokenizer(
        nlp_tokens,
        is_split_into_words=True,
        return_tensors="pt",
        padding=True,
        truncation=True
    )
    
    # Get predictions
    with torch.no_grad():
        outputs = model(**encoding)
        predictions = torch.argmax(outputs.logits, dim=-1)
    
    # Map predictions back to original tokens
    predicted_labels = []
    previous_word_idx = None
    
    # Use word_ids to map predictions back to original tokens
    word_ids = encoding.word_ids(batch_index=0)
    
    for idx, word_idx in enumerate(word_ids):
        # Skip special tokens and tokens that we've already assigned labels to
        if word_idx is None or word_idx == previous_word_idx:
            continue
            
        # Get prediction for this token (use the first subword's prediction)
        pred_id = predictions[0, idx].item()
        
        # Map prediction ID to label name
        if pred_id in id_to_label:
            label = id_to_label[pred_id]
        else:
            label = "O"  # Default to Outside tag
            
        # Add prediction
        predicted_labels.append(label)
        previous_word_idx = word_idx
    
    # Return tokens with their predicted labels
    return list(zip(nlp_tokens, predicted_labels))

# 8. Prediction example

In [78]:
# Example usage
model, tokenizer, id_to_label = load_model(model_output_path)

# Test it on a sample sentence
test_sentence = "If a pandemic happens, the WHO should administer vaccines before 30 days has expired, failing which they will be imposed a fine of twenty million euros."
results = predict_sequence_tags(test_sentence, model, tokenizer, id_to_label)

# Print results in a readable format
print("Token\tLabel")
print("-" * 30)
for token, label in results:
    print(f"{token}\t{label}")


Token	Label
------------------------------
If	B-Or_Else
a	I-Activation_Condition
pandemic	B-Aim
happens,	B-Aim
the	B-Aim
WHO	B-Or_Else
should	I-Activation_Condition
administer	I-Activation_Condition
vaccines	B-Aim
before	B-Aim
30	I-Or_Else
days	B-Activation_Condition
has	I-Activation_Condition
expired,	I-Or_Else
failing	I-Execution_Constraint
which	B-Aim
they	B-Or_Else
will	I-Activation_Condition
be	I-Execution_Constraint
imposed	B-Aim
a	I-Activation_Condition
fine	I-Execution_Constraint
of	B-Aim
twenty	I-Execution_Constraint
million	B-Aim
euros.	B-Aim


# 9. Visualise prediction function

In [79]:
from IPython.display import display, HTML

def print_colored_results(results):
    """Print results with color-coding using HTML for Jupyter notebooks and include a legend."""
    
    color_map = {
        "B-Activation_Condition": "red",
        "I-Activation_Condition": "red",
        "B-Attribute": "blue",
        "I-Attribute": "blue",
        "B-Deontic": "green",
        "I-Deontic": "green",
        "B-Aim": "orange",
        "I-Aim": "orange",
        "B-Object": "magenta",
        "I-Object": "magenta",
        "B-Execution_Constraint": "cyan",
        "I-Execution_Constraint": "cyan",
        "B-Or_Else": "lightcoral",
        "I-Or_Else": "lightcoral",
        "O": "black"
    }

    # Generate the legend
    legend_html = "<p><b>Legend:</b></p><table style='border-collapse: collapse;'>"
    for tag, color in color_map.items():
        legend_html += f"""
        <tr>
            <td style='background-color: {color}; width: 20px; height: 20px;'></td>
            <td style='padding-left: 10px;'>{tag}</td>
        </tr>
        """
    legend_html += "</table><br>"

    # Generate the colored text output
    text_html = ""
    for token, label in results:
        color = color_map.get(label, "black")
        text_html += f'<span style="color: {color}; font-weight: bold;">{token}</span> '

    # Display legend and text
    display(HTML(legend_html + text_html))

# 10. Visualisation example

In [80]:
# Visual output
print_colored_results(results)

0,1
,B-Activation_Condition
,I-Activation_Condition
,B-Attribute
,I-Attribute
,B-Deontic
,I-Deontic
,B-Aim
,I-Aim
,B-Object
,I-Object
