In [None]:
# !pip install -q -U pillow git+https://github.com/huggingface/transformers git+https://github.com/huggingface/peft.git datasets pymupdf bitsandbytes seaborn matplotlib

## Load Dataset

In [None]:
from datasets import load_dataset

# Load the SEC material-contracts QA dataset by its repository id; the evaluation
# loop further down iterates over ds["test"].
ds = load_dataset("chenghao/sec-material-contracts-qa-splitted")

## Load Model

In [None]:
import torch
from transformers import AutoProcessor, Idefics2ForConditionalGeneration, BitsAndBytesConfig

# Checkpoint id of the original Idefics2 release.
base_model = "HuggingFaceM4/idefics2-8b"

# Checkpoint whose weights are actually loaded and scored. Swap in the commented
# id (presumably the fine-tuned variant) to evaluate that one instead.
# peft_model_id = "chenghao/idefics2-edgar"
peft_model_id = "HuggingFaceM4/idefics2-8b"

# One dtype shared by the quantised compute path and the non-quantised modules.
compute_dtype = torch.float16

# 4-bit NF4 weights with double quantisation.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)
model = Idefics2ForConditionalGeneration.from_pretrained(
    peft_model_id,
    quantization_config=quantization_config,
    torch_dtype=compute_dtype,
)

# Inference only: put the module in eval mode.
model.eval()

## Evaluation

In [None]:
from tqdm.auto import tqdm
import time

# Build the processor from `base_model`. Image splitting is turned off and the
# resize bounds are set explicitly (presumably to keep the per-page image cost
# small, since several pages are sent per question -- confirm).
processor = AutoProcessor.from_pretrained(
    base_model,
    size={"longest_edge": 490, "shortest_edge": 350},
    do_image_splitting=False,
)

def evaluate():
    """Generate an answer for every example in ``ds["test"]``.

    For each example the first two and last two page images are shown to the
    model together with the question. The text following the first
    ``"Assistant:"`` marker in the decoded output is kept as the prediction.
    Documents with four pages or fewer contribute each page exactly once.

    Relies on the notebook-level ``ds``, ``processor`` and ``model`` objects.

    Returns:
        tuple[list, list, list]: ``(questions, answers, predictions)`` -- three
        index-aligned lists with one entry per test example.
    """
    questions = []
    answers = []
    predictions = []

    for example in tqdm(ds["test"]):
        pages = example["images"]
        # Taking head and tail slices blindly would feed the same page twice
        # whenever a document has fewer than four pages (e.g. 3 pages ->
        # [0, 1, 1, 2]); short documents therefore use every page once.
        # NOTE(review): if a fine-tuned checkpoint was trained with duplicated
        # pages, keep its training preprocessing in sync with this.
        if len(pages) > 4:
            images = pages[:2] + pages[-2:]
        else:
            images = list(pages)
        question, answer = example["question"], example["answer"]

        # One image placeholder per page, followed by the question text.
        messages = [
            {
                "role": "user",
                "content": [{"type": "image"} for _ in images] + [{"type": "text", "text": question}],
            },
        ]

        prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
        # Send the inputs to wherever the model lives instead of assuming "cuda".
        inputs = processor(text=prompt, images=images, return_tensors="pt").to(model.device)
        with torch.no_grad():
            generated_ids = model.generate(**inputs, max_new_tokens=1024)
        generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)
        # The decoded sequence still contains the prompt; keep only what follows
        # the first "Assistant:" marker.
        preds = [t.split("Assistant:", 1)[-1].strip() for t in generated_texts]

        questions.append(question)
        answers.append(answer)
        predictions.append(preds[0])

    return questions, answers, predictions

# Run generation over the whole test split; the three returned lists are index-aligned.
questions, answers, predictions = evaluate()

In [None]:
# answers[:10], predictions[:10]

In [None]:
from nltk import edit_distance
import numpy as np
from collections import defaultdict

# (field key, question text) pairs. The question strings are compared verbatim
# with the question of each evaluated example, so they must not be reworded.
_FIELD_QUESTIONS = (
    ("agreement_date", "When is the signing date of this agreement?"),
    ("effective_date", "When is the effective date of the contract?"),
    ("expiration_date", "When is the service end date or expiration date of the contract?"),
    ("party_address", "What is the address of the party to the contract?"),
    ("party_name", "What are the names of the contracting party?"),
    ("counterparty_address", "What is the address of the counterparty to the contract?"),
    ("counterparty_name", "What are the names of the contracting counterparty?"),
    ("counterparty_signer_name", "What is the name of the counterparty signer for each party to the agreement?"),
    ("counterparty_signer_title", "What is the counterparty signer’s title?"),
    ("auto_renewal", "Whether the contract term automatically renews (true/false)."),
    ("governing_law", "Where is the jurisdiction or choice of law?"),
    ("venue", "where is the location of the courts where legal proceedings will take place?"),
    ("payment_frequency", "what is the cadence for which payments are made (e.g., monthly, annually, one-time)?"),
    ("payment_term", "When an invoice is due after issuance (e.g. Net 30)?"),
    ("renewal_term", "What is the length of time the renewal period will last (e.g., 1 year, 2 years, 24 months etc.)?"),
    ("agreement_term", "What is the term of the contract as an amount of time (e.g., 24 months)?"),
    ("termination_for_cause", "Whether one or all parties may terminate the contract with cause, such as a breach of contract (true/false)."),
    ("termination_for_convenience", "Whether one or all parties may terminate the contract without cause, or at their convenience (true/false)."),
    ("termination_notice_period", "What is the period by which notice of termination must be given (e.g., 30 days)?"),
    ("opt_out_length", "What is the required notice period to NOT renew (e.g., 30 days)?"),
    ("contract_value", "What is the total fixed fee amount including currency codes or symbols?"),
)

# Forward lookup: field key -> question text.
key2question = dict(_FIELD_QUESTIONS)
# Reverse lookup: question text -> field key.
question2key = {text: field for field, text in _FIELD_QUESTIONS}

def calculate_edit_distance(questions, answers, predictions):
    """Score predictions by length-normalised edit distance (0 = exact match).

    Each distance is divided by the length of the longer of the two strings.
    Examples whose question is not a key of ``question2key`` are skipped.

    Args:
        questions: Question text per example.
        answers: Reference answer string per example.
        predictions: Predicted answer string per example.

    Returns:
        tuple: ``(per_field, overall)`` where ``per_field`` maps each field key
        to the mean normalised distance of its examples and ``overall`` is the
        mean over all scored examples (``nan`` when nothing was scored).
    """
    scores = defaultdict(list)
    for question, pred, answer in tqdm(zip(questions, predictions, answers), total=len(answers)):
        field = question2key.get(question)
        if field is None:
            continue
        longest = max(len(pred), len(answer))
        # Two empty strings are an exact match; score 0 instead of dividing by zero.
        normalised = edit_distance(pred, answer) / longest if longest else 0.0
        scores[field].append(normalised)

    per_field = {field: np.mean(values) for field, values in scores.items()}
    all_scores = [value for values in scores.values() for value in values]
    # np.mean([]) emits a RuntimeWarning before returning nan; return nan directly.
    overall = np.mean(all_scores) if all_scores else float("nan")
    return per_field, overall

In [None]:
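# Uncomment this cell to save the current run. The Comparison section below reads
# "base.json" and "finetuned.json", so write one file per evaluated checkpoint.
# import json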

# with open("base.json", "w") as f:
#     f.write(json.dumps({
#         "questions": questions,
#         "answers": answers,
#         "predictions": predictions,
#     }))

In [None]:
# {'counterparty_name': 0.6333587323616815,
 # 'effective_date': 0.12564102564102564,
 # 'payment_frequency': 0.3737244897959184,
 # 'agreement_date': 0.0999478526464782,
 # 'governing_law': 0.30877104014358914,
 # 'party_address': 0.6083013068679771,
 # 'termination_for_cause': 0.04800000000000001,
 # 'venue': 0.614029648733949,
 # 'party_name': 0.4901937256717742,
 # 'renewal_term': 0.05952380952380952,
 # 'expiration_date': 0.23519736842105257,
 # 'auto_renewal': 0.05161290322580645,
 # 'counterparty_signer_name': 0.4804435106679111,
 # 'termination_for_convenience': 0.1565217391304348,
 # 'agreement_term': 0.4388158040994175,
 # 'termination_notice_period': 0.1783935413245758,
 # 'counterparty_signer_title': 0.4960412079943858,
 # 'counterparty_address': 0.5983504545337417,
 # 'contract_value': 0.4188154703882048,
 # 'payment_term': 0.5933333333333334,
 # 'opt_out_length': 0.047619047619047616}

# {'counterparty_name': 0.8254906089175902,
#  'effective_date': 0.9032678569638961,
#  'payment_frequency': 0.6861225538105238,
#  'agreement_date': 0.8784888955727604,
#  'governing_law': 0.8810370791576468,
#  'party_address': 0.7308967559080009,
#  'termination_for_cause': 0.43599999999999994,
#  'venue': 0.7814165648930487,
#  'party_name': 0.7264106186433235,
#  'renewal_term': 0.928290301068474,
#  'expiration_date': 0.8867295058561521,
#  'auto_renewal': 0.6349462365591401,
#  'counterparty_signer_name': 0.8420912525335033,
#  'termination_for_convenience': 0.6282608695652173,
#  'agreement_term': 0.9070674746697442,
#  'termination_notice_period': 0.3297482211275315,
#  'counterparty_signer_title': 0.6174595001757442,
#  'counterparty_address': 0.7713865641438143,
#  'contract_value': 0.47443779675428815,
#  'payment_term': 0.8545520817994479,
#  'opt_out_length': 0.43154761904761907}

## Comparison

In [None]:
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
import json
from tqdm import tqdm

def calculate_df(file1, file2):
    """Score two saved runs and return their per-field results in long format.

    Each file is a JSON object with ``questions``, ``answers`` and
    ``predictions`` lists. The relative reduction of the overall score from
    ``file1`` to ``file2`` is printed as a percentage.

    Args:
        file1: Path of the first run (the reference for the printed percentage).
        file2: Path of the second run.

    Returns:
        pandas.DataFrame: Columns ``category``, ``value`` and ``model``, with the
        rows of ``file1`` first. ``model`` is the display name registered for
        the file's stem, or the stem itself for unregistered runs.
    """
    from pathlib import Path  # local import keeps this cell self-contained

    display_names = {
        "base": "Idefics2-8B",
        "finetuned": "Idefics2-8B-EDGAR",
    }

    def _score_run(path):
        # Load one saved run; return (per-field scores, overall score, label).
        with open(path, encoding="utf-8") as f:
            run = json.load(f)
        per_field, overall = calculate_edit_distance(run["questions"], run["answers"], run["predictions"])
        # Use the stem so paths with directories still resolve to "base" or
        # "finetuned"; unknown runs fall back to their stem instead of KeyError.
        stem = Path(path).stem
        return per_field, overall, display_names.get(stem, stem)

    r1, a1, n1 = _score_run(file1)
    r2, a2, n2 = _score_run(file2)

    print(f"{(a1-a2)/a1*100:.2f}%")

    records = [(key, value, n1) for key, value in r1.items()]
    records.extend((key, value, n2) for key, value in r2.items())

    return pd.DataFrame(records, columns=["category", "value", "model"])


# Enlarge fonts before the figure is created so the tall canvas stays readable.
sns.set(font_scale=1.4)
fig, ax = plt.subplots(figsize=(20, 20))

# Long-format scores of both saved runs: one row per (category, model).
df = calculate_df("base.json", "finetuned.json")

# Horizontal bars, one pair per category, coloured by model.
sns.barplot(data=df, x="value", y="category", hue="model", ax=ax)
ax.set(xlabel="Average Edit Distance (lower is better)", ylabel="Category")
ax.set_title("Comparison between base and finetuned model")

In [None]:
# Per-category table: base score, the other model's score, and the relative
# reduction of the score with respect to the base model.
data = []
for category, group in df.groupby("category"):
    is_base = group["model"] == "Idefics2-8B"
    base_score = group.loc[is_base, "value"].iloc[0]
    other_score = group.loc[~is_base, "value"].iloc[0]
    relative_drop = (base_score - other_score) / base_score
    data.append((category, base_score, other_score, f"{relative_drop * 100:.2f}%"))

summary_columns = ["Category", "Idefics2-8B", "Idefics2-8B-EDGAR", "Δ(↑)"]
print(pd.DataFrame(data, columns=summary_columns).to_markdown())