<a href="https://colab.research.google.com/github/Thoran37/Multi-Agent-Law-Framework/blob/main/dataset/LaCour!_Preprocessing.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import xml.etree.ElementTree as ET

In [3]:
import os
from lxml import etree
import pandas as pd

DATA_DIR = "/content/LaCour/"

def normalize_role(role):
    if not role:
        return "OTHER"
    role = role.lower()
    if any(r in role for r in ["applicant", "counsel", "lawyer1"]):
        return "LAWYER1"
    if any(r in role for r in ["government", "agent", "lawyer2"]):
        return "LAWYER2"
    if any(r in role for r in ["judge", "president"]):
        return "JUDGE"
    return "OTHER"

def extract_case_id(filename):
    base = os.path.splitext(filename)[0]
    parts = base.split("_", 1)
    return parts[1] if len(parts) > 1 else base

rows = []

for filename in os.listdir(DATA_DIR):
    if filename.endswith(".xml"):
        file_path = os.path.join(DATA_DIR, filename)
        case_id = extract_case_id(filename)

        try:
            tree = etree.parse(file_path)
        except Exception as e:
            print("Error parsing:", filename, e)
            continue

        segments = tree.xpath("//SpeakerSegment/Segment")

        turn = 0
        role_counter = {}
        previous_role = "NONE"  # <-- initialize previous speaker role

        for seg in segments:
            lang = seg.findtext("meta_data/Language")
            if lang != "en":
                continue

            role = normalize_role(seg.findtext("meta_data/Role"))
            if role == "OTHER":
                continue

            text = seg.findtext("text")
            if not text:
                continue
            text = text.strip()

            start = seg.findtext("meta_data/TimestampBegin")
            end = seg.findtext("meta_data/TimestampEnd")
            start = float(start) if start else None
            end = float(end) if end else None
            duration = (end - start) if (start is not None and end is not None) else None

            role_counter[role] = role_counter.get(role, 0) + 1
            role_turn = role_counter[role]

            rows.append({
                "case_id": case_id,
                "speaker_role": role,
                "prev_speaker_role": previous_role,   # <-- store previous role
                "start_time": start,
                "end_time": end,
                "duration": duration,
                "text": text,
                "turn": turn,
                "role_turn": role_turn,
                "file": filename
            })

            previous_role = role  # <-- update previous role
            turn += 1

df = pd.DataFrame(rows)
df.to_csv("LaCour_preprocessed.csv", index=False)

df.head(), len(df)

(            case_id speaker_role prev_speaker_role  start_time  end_time  \
 0  4887608_07032012        JUDGE              NONE       52.38     53.38   
 1  4887608_07032012        JUDGE             JUDGE       61.14     70.76   
 2  4887608_07032012        JUDGE             JUDGE       70.76     85.58   
 3  4887608_07032012        JUDGE             JUDGE       85.58     93.98   
 4  4887608_07032012        JUDGE             JUDGE       93.98    100.74   
 
    duration                                               text  turn  \
 0      1.00                                   Please sit down.     0   
 1      9.62  I declare open the public hearing on the admis...     1   
 2     14.82  The case was lodged on the 11th September 2008...     2   
 3      8.40  The application was allocated to the fourth se...     3   
 4      6.76  The application was communicated to the govern...     4   
 
    role_turn                             file  
 0          1  transcript_4887608_07032012.xml 

In [4]:
import os
from lxml import etree
import pandas as pd

DATA_DIR = "/content/LaCour/"

def normalize_role(role):
    if not role:
        return "OTHER"
    role = role.lower()
    if any(r in role for r in ["applicant", "counsel", "lawyer1"]):
        return "LAWYER1"
    if any(r in role for r in ["government", "agent", "lawyer2"]):
        return "LAWYER2"
    if any(r in role for r in ["judge", "president"]):
        return "JUDGE"
    return "OTHER"

def extract_case_id(filename):
    base = os.path.splitext(filename)[0]
    parts = base.split("_", 1)
    return parts[1] if len(parts) > 1 else base

rows = []

for filename in os.listdir(DATA_DIR):
    if filename.endswith(".xml"):
        file_path = os.path.join(DATA_DIR, filename)
        case_id = extract_case_id(filename)

        try:
            tree = etree.parse(file_path)
        except Exception as e:
            print("Error parsing:", filename, e)
            continue

        segments = tree.xpath("//SpeakerSegment/Segment")

        turn = 0
        role_counter = {}
        previous_role = "NONE"

        merged_role = None
        merged_start = None
        merged_end = None
        merged_text = ""

        for seg in segments:
            lang = seg.findtext("meta_data/Language")
            if lang != "en":
                continue

            role = normalize_role(seg.findtext("meta_data/Role"))
            if role == "OTHER":
                continue

            text = seg.findtext("text")
            if not text:
                continue
            text = text.strip()

            start = seg.findtext("meta_data/TimestampBegin")
            end = seg.findtext("meta_data/TimestampEnd")
            start = float(start) if start else None
            end = float(end) if end else None

            # If same speaker as previous, merge text
            if merged_role == role:
                merged_text += " " + text
                merged_end = end  # update end time
                continue

            # If speaker changed, save previous merged row
            if merged_role is not None:
                duration = merged_end - merged_start if merged_start and merged_end else None
                role_counter[merged_role] = role_counter.get(merged_role, 0) + 1
                role_turn = role_counter[merged_role]

                rows.append({
                    "case_id": case_id,
                    "speaker_role": merged_role,
                    "prev_speaker_role": previous_role,
                    "start_time": merged_start,
                    "end_time": merged_end,
                    "duration": duration,
                    "text": merged_text.strip(),
                    "turn": turn,
                    "role_turn": role_turn,
                    "file": filename
                })

                previous_role = merged_role
                turn += 1

            # Start new merge block
            merged_role = role
            merged_start = start
            merged_end = end
            merged_text = text

        # Save last merged row at end of file
        if merged_role is not None:
            duration = merged_end - merged_start if merged_start and merged_end else None
            role_counter[merged_role] = role_counter.get(merged_role, 0) + 1
            role_turn = role_counter[merged_role]

            rows.append({
                "case_id": case_id,
                "speaker_role": merged_role,
                "prev_speaker_role": previous_role,
                "start_time": merged_start,
                "end_time": merged_end,
                "duration": duration,
                "text": merged_text.strip(),
                "turn": turn,
                "role_turn": role_turn,
                "file": filename
            })

df = pd.DataFrame(rows)
df.to_csv("LaCour_merged.csv", index=False)

df.head(), len(df)


(            case_id speaker_role prev_speaker_role  start_time  end_time  \
 0  4887608_07032012        JUDGE              NONE       52.38    221.22   
 1  4887608_07032012      LAWYER2             JUDGE      225.18   1951.02   
 2  4887608_07032012        JUDGE           LAWYER2     1953.36   1958.52   
 3  4887608_07032012      LAWYER1             JUDGE     1960.81   3771.03   
 4  4887608_07032012        JUDGE           LAWYER1     3775.06   3816.90   
 
    duration                                               text  turn  \
 0    168.84  Please sit down. I declare open the public hea...     0   
 1   1725.84  Madam President, Members of the Court, Message...     1   
 2      5.16  Thank you very much, Mr Chamberlain. Now I cal...     2   
 3   1810.22  Madam President, members of the court, in 2005...     3   
 4     41.84  Thank you very much, Mr. Tomlinson. Now I woul...     4   
 
    role_turn                             file  
 0          1  transcript_4887608_07032012.xml 

#DeepSeek R1

In [5]:
!pip install transformers datasets accelerate bitsandbytes peft -q

In [6]:
import pandas as pd

df = pd.read_csv("LaCour_merged.csv")

# Sort to guarantee correct order
df = df.sort_values(by=["case_id", "turn"]).reset_index(drop=True)

# Build previous text column
prev_texts = []
for i, row in df.iterrows():
    if i == 0 or df.loc[i, "case_id"] != df.loc[i-1, "case_id"]:
        prev_texts.append("None")
    else:
        prev_texts.append(df.loc[i-1, "text"])

df["previous_text"] = prev_texts

df.head()

Unnamed: 0,case_id,speaker_role,prev_speaker_role,start_time,end_time,duration,text,turn,role_turn,file,previous_text
0,1021112_29112017,JUDGE,NONE,50.37,272.37,222.0,Please be seated. I declare open the public he...,0,1,transcript_1021112_29112017.xml,
1,1021112_29112017,LAWYER1,JUDGE,277.99,1173.31,895.32,"Mr. President, distinguished judges of the cou...",1,1,transcript_1021112_29112017.xml,Please be seated. I declare open the public he...
2,1021112_29112017,JUDGE,LAWYER1,1179.22,1182.22,3.0,"Please, Mr. Mavany, you have the floor.",2,2,transcript_1021112_29112017.xml,"Mr. President, distinguished judges of the cou..."
3,1021112_29112017,LAWYER1,JUDGE,1185.8,1962.4,776.6,"Mr. President, Honorable Judges of the Court, ...",3,2,transcript_1021112_29112017.xml,"Please, Mr. Mavany, you have the floor."
4,1021112_29112017,JUDGE,LAWYER1,1964.5,1969.98,5.48,"Thank you very much, Mr. Mavany. I call Profes...",4,3,transcript_1021112_29112017.xml,"Mr. President, Honorable Judges of the Court, ..."


In [7]:
def make_prompt(row):
    previous = row["previous_text"] if isinstance(row["previous_text"], str) else "None"
    return f"""You are participating in a European Court of Human Rights hearing.

Case: {row['case_id']}
Your role: {row['speaker_role']}
Previous speaker: {row['prev_speaker_role']}
Previous text: {previous}

Respond in formal courtroom language, without chain of thought, and stay in character as {row['speaker_role']}.

Response:"""

df["input_text"] = df.apply(make_prompt, axis=1)
df["output_text"] = df["text"]

df[["input_text", "output_text"]].head()

Unnamed: 0,input_text,output_text
0,You are participating in a European Court of H...,Please be seated. I declare open the public he...
1,You are participating in a European Court of H...,"Mr. President, distinguished judges of the cou..."
2,You are participating in a European Court of H...,"Please, Mr. Mavany, you have the floor."
3,You are participating in a European Court of H...,"Mr. President, Honorable Judges of the Court, ..."
4,You are participating in a European Court of H...,"Thank you very much, Mr. Mavany. I call Profes..."


In [8]:
from datasets import Dataset

dataset = Dataset.from_pandas(df[["input_text", "output_text"]])


In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "deepseek-ai/deepseek-r1-distill-qwen-7b"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    load_in_4bit=True,
    device_map="auto"
)


The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
def tokenize(example):
    x = tokenizer(example["input_text"], truncation=True, max_length=2048)
    y = tokenizer(example["output_text"], truncation=True, max_length=512)
    x["labels"] = y["input_ids"]
    return x

tokenized = dataset.map(tokenize, batched=False)


In [None]:
from transformers import TrainingArguments, Trainer

training_args = TrainingArguments(
    output_dir="./lacour_deepseek_r1_model",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    learning_rate=2e-5,
    num_train_epochs=2,
    fp16=True,
    logging_steps=20,
    save_strategy="epoch"
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized
)

trainer.train()


In [None]:
prompt = """You are role LAWYER1.
The previous speaker was JUDGE, who said:
"The Court invites counsel for the applicant to present their submissions."

Respond in formal courtroom style, without chain of thought.

Response:"""

inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=200)
print(tokenizer.decode(output[0], skip_special_tokens=True))
