In [None]:
# !pip install -q datasets pymupdf openai python-dotenv langchain langchain-openai pillow trafilatura weasyprint

## Load Raw Dataset

Load the SEC material-contracts dataset from the Hugging Face Hub. Each record's `file_content` holds the raw filing, including its HTML.

In [None]:
import datetime
import io
import re
import time
from urllib.parse import urljoin

import fitz
import tqdm
from bs4 import BeautifulSoup
from datasets import load_dataset
from dotenv import load_dotenv
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_openai import ChatOpenAI
from PIL import Image
from trafilatura import html2txt
from weasyprint import HTML

# Hugging Face Hub dataset of SEC material contracts; each row carries the raw filing content.
source_repo = "chenghao/sec-material-contracts"
dataset = load_dataset(path=source_repo)

### Convert HTML to PDF and Images

In [None]:
def convert_relative_url_to_absolute_url(base_url, html):
    """Return ``html`` with link targets and image sources resolved against ``base_url``.

    Every ``<a href>`` is rewritten first, then every ``<img src>``, using
    ``urllib.parse.urljoin``. Values that are already absolute are left as-is,
    because ``urljoin`` returns them unchanged.

    Args:
        base_url: URL that relative references are resolved against.
        html: HTML markup to rewrite.

    Returns:
        The re-serialized HTML as a string.
    """
    soup = BeautifulSoup(html, "html.parser")
    # (tag name, attribute holding a URL) pairs, processed in this order.
    url_attributes = (("a", "href"), ("img", "src"))
    for tag_name, attribute in url_attributes:
        for element in soup.find_all(tag_name, **{attribute: True}):
            element[attribute] = urljoin(base_url, element[attribute])
    return str(soup)


def convert_html_to_pdf(html, dpi=120):
    """Render HTML to PDF and return each page's text and a raster image of it.

    Despite the name, no PDF is returned. WeasyPrint renders ``html`` to PDF bytes
    in memory, and PyMuPDF reads those bytes back to extract each page's text and
    rasterize the page.

    Args:
        html: HTML markup to render.
        dpi: Resolution used when rasterizing pages (default 120, as before).

    Returns:
        ``(page_text, images)``: two lists in page order. ``page_text`` holds the
        text of each page. ``images`` holds an RGB ``PIL.Image.Image`` of each page.
    """
    pdf_bytes = HTML(string=html).write_pdf()
    page_text = []
    images = []
    # PyMuPDF accepts raw bytes as a stream, so no BytesIO wrapper is needed.
    # The context manager closes the document explicitly instead of relying on
    # garbage collection; this runs once per record inside a multi-process map.
    with fitz.open(stream=pdf_bytes, filetype="pdf") as doc:
        for page in doc:
            pixmap = page.get_pixmap(dpi=dpi)
            images.append(Image.frombytes("RGB", (pixmap.width, pixmap.height), pixmap.samples))
            page_text.append(page.get_text())
    return page_text, images

In [None]:
def extract_text(record):
    """Extract plain text, per-page text, and page images from one contract record.

    The raw filing in ``record['file_content']`` is searched for an ``<html ...>``
    opening tag and the first ``</html>`` after it. The markup between them is:
      * converted to plain text with trafilatura (``full_text``);
      * rewritten so relative links/images become absolute URLs, based on
        ``record['index_html_url']``;
      * rendered to PDF to obtain per-page text and page images.

    Args:
        record: A dataset row with ``file_content`` and ``index_html_url`` keys.

    Returns:
        A dict with ``full_text``, ``images``, ``page_text`` and ``html_content``.
        Every value is empty when the record has no content or no complete
        ``<html ...>...</html>`` document.
    """
    empty = {"full_text": "", "images": [], "page_text": [], "html_content": ""}

    file_content = record['file_content']
    if not file_content:
        return empty

    # Search the original string case-insensitively instead of slicing it with
    # indices found in ``file_content.lower()``. Lower-casing can change string
    # length for some non-ASCII characters, which would misalign the slice.
    # The pattern also accepts attributes on the tag (e.g. ``<html lang="en">``),
    # which an exact ``"<html>"`` substring test silently rejects.
    opening = re.search(r"<html\b[^>]*>", file_content, flags=re.IGNORECASE)
    if opening is None:
        return empty
    # Only look for the closing tag after the opening one, so a stray earlier
    # ``</html>`` cannot produce an inverted (empty) slice.
    closing = re.compile(r"</html\s*>", re.IGNORECASE).search(file_content, opening.end())
    if closing is None:
        return empty

    html_content = file_content[opening.end():closing.start()]

    full_text = html2txt(html_content)
    index_url = record['index_html_url']
    # NOTE(review): assumes the filing's documents live under the index URL with
    # "-index.html" and all hyphens removed — confirm against the URL format.
    base_url = index_url.replace("-index.html", "").replace("-", "") + "/"
    html_content = convert_relative_url_to_absolute_url(base_url, html_content)
    page_text, images = convert_html_to_pdf(html_content)

    return {"full_text": full_text, "images": images, "page_text": page_text, "html_content": html_content}

### Sample Documents

In [None]:
# Keep rows whose `date` is on or after 2024-01-01, then run `extract_text` on each
# of them across 8 worker processes.
cutoff = datetime.datetime(year=2024, month=1, day=1)
recent_contracts = dataset['train'].filter(lambda row: row['date'] >= cutoff)
sample = recent_contracts.map(extract_text, num_proc=8)

In [None]:
def _is_usable_sample(record):
    """True when a record has extracted text and 1-20 pages, each with at least 50 characters."""
    pages = record['page_text']
    if not record['full_text'] or not 1 <= len(pages) <= 20:
        return False
    return all(len(page) >= 50 for page in pages)


final_sample = sample.filter(_is_usable_sample, num_proc=8)

### Extract Key Information with OpenAI

In [None]:
# https://support.ironcladapp.com/hc/en-us/articles/12947738534935-Ironclad-AI-Overview
class KeyInformation(BaseModel):
    # Structured schema the LLM fills in for each contract. Every field is a
    # required string; the descriptions are sent to the model as field guidance.
    # No class docstring on purpose: it would become the schema description
    # passed to the model. Field order is kept, since it determines the order of
    # keys in the extracted dict.

    # --- Dates ---
    agreement_date: str = Field(..., description="Agreement signing date of the contract. (date)")
    effective_date: str = Field(..., description="Effective date of the contract. (date)")
    expiration_date: str = Field(..., description="Service end date or expiration date of the contract. (date)")

    # --- Parties and signers ---
    party_address: str = Field(..., description="Address of the party to the contract.")
    party_name: str = Field(..., description="The names of the contracting party.")
    counterparty_address: str = Field(..., description="Address of the counterparty to the contract.")
    counterparty_name: str = Field(..., description="The names of the contracting counterparty.")
    counterparty_signer_name: str = Field(
        ...,
        description="The name of the counterparty signer for each party to the agreement.",
    )
    counterparty_signer_title: str = Field(..., description="The counterparty signer’s title (e.g., CEO).")

    # --- Renewal, law and venue ---
    auto_renewal: str = Field(..., description="Whether the contract term automatically renews (true/false).")
    governing_law: str = Field(..., description="(Jurisdiction) Choice of law.")
    venue: str = Field(..., description="Location of the courts where legal proceedings will take place.")

    # --- Payment and term ---
    payment_frequency: str = Field(
        ...,
        description="The cadence for which payments are made (e.g., monthly, annually, one-time).",
    )
    payment_term: str = Field(..., description="When an invoice is due after issuance (e.g. Net 30)")
    renewal_term: str = Field(
        ...,
        description="The length of time the renewal period will last (e.g., 1 year, 2 years, 24 months etc.).",
    )
    agreement_term: str = Field(..., description="Term of the contract as an amount of time (e.g., 24 months).")

    # --- Termination ---
    termination_for_cause: str = Field(
        ...,
        description="Whether one or all parties may terminate the contract with cause, such as a breach of contract (true/false).",
    )
    termination_for_convenience: str = Field(
        ...,
        description="Whether one or all parties may terminate the contract without cause, or at their convenience (true/false).",
    )
    termination_notice_period: str = Field(
        ...,
        description="The period by which notice of termination must be given (e.g., 30 days).",
    )
    opt_out_length: str = Field(..., description="Required notice period to NOT renew (e.g., 30 days).")

    # --- Value ---
    contract_value: str = Field(
        ...,
        description="Total fixed fee amount including currency codes or symbols. (monetary amount)",
    )


# NOTE(review): presumably the local .env supplies the OpenAI credentials used by ChatOpenAI — confirm.
load_dotenv()
llm = ChatOpenAI(model="gpt-4o", max_retries=2)
# Bind the KeyInformation schema so replies are parsed into KeyInformation objects.
model = llm.with_structured_output(KeyInformation)
chain = model  # name used by text2qa

In [None]:
def text2qa(record):
    """Extract the KeyInformation fields for one contract via the LLM ``chain``.

    The contract text is the record's pages joined by blank lines. It is sent as
    the user message, after a system prompt that asks for structured extraction
    with "N/A" for missing values.

    Args:
        record: A dataset row with a ``page_text`` list of strings.

    Returns:
        A dict built from the structured result returned by ``chain.invoke``.
    """
    contract_text = "\n\n".join(record['page_text'])

    instructions = SystemMessage(
        content='''You are a legal expert who is helping a client understand a contract. The client asks you to extract the key information for the given contract and return them in a structured format. Use N/A if not applicable or not available.''')
    contract = HumanMessage(content=contract_text)

    extracted = chain.invoke([instructions, contract])
    return dict(extracted)

In [None]:
# text2qa(final_sample[0])

In [None]:
# Extract key information for every record. Failed records are retried in further
# rounds; between rounds the user confirms whether to keep going.
data = final_sample
# Size the list without iterating the dataset (which would decode every row, images included).
results = [None] * len(data)
errors = list(range(len(data)))
while errors:
    new_errors = []
    for i in tqdm.tqdm(errors):
        try:
            results[i] = text2qa(data[i])
        except Exception as e:
            # Best-effort: report the failure and queue the record for the next round.
            print(f"Error at {i}: {e}")
            new_errors.append(i)
        else:
            # Fixed pause after each successful call, presumably to stay under API rate limits.
            time.sleep(5)

    errors = new_errors
    # Nothing left to retry: finish without a pointless "Found 0 errors" prompt.
    if not errors:
        break
    cont = input(f"Found {len(errors)} errors. Continue? (y/n)")
    if cont.strip().lower() != "y":
        break

In [None]:
# Merge each contract record with its extracted key information. Records whose
# extraction never succeeded (result still None, e.g. after stopping the retry
# loop early) are skipped and reported, instead of raising TypeError on `dict | None`.
output = []
missing = []
for i, (record, extracted) in enumerate(zip(final_sample, results)):
    if extracted is None:
        missing.append(i)
        continue
    output.append(record | extracted)
if missing:
    print(f"Skipping {len(missing)} record(s) without extracted key information: {missing}")

import pandas as pd
from datasets import Dataset

df = pd.DataFrame(output)
ds = Dataset.from_pandas(df)
ds.save_to_disk("temp-data")

In [None]:
# Publish the merged contracts + key-information dataset to the Hugging Face Hub.
ds.push_to_hub(repo_id="chenghao/sec-material-contracts-qa")

In [None]:
# !pip install -q transformers tokenizers quanto accelerate bitsandbytes -U

## Quick test of Idefics2 on the Dataset

In [None]:
import io

import torch

from datasets import load_from_disk
from PIL import Image
from transformers import Idefics2Processor, Idefics2ForConditionalGeneration
from IPython.display import display

ds = load_from_disk(dataset_path="temp-data")  # the merged dataset written by save_to_disk above

In [None]:
# Pick an available accelerator instead of hard-coding Apple's "mps", which fails
# elsewhere. On Apple silicon without CUDA this still resolves to "mps".
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

model_id = "HuggingFaceM4/idefics2-8b"
model = Idefics2ForConditionalGeneration.from_pretrained(model_id, device_map=device, torch_dtype=torch.float16)
model.eval()

In [None]:
idx = 0
record = ds[idx]  # decode the row once instead of on every access
# NOTE(review): assumes each entry of the saved "images" column is a dict holding
# the encoded image under "bytes" — confirm against the dataset schema.
images = [Image.open(io.BytesIO(image['bytes'])) for image in record['images']]
answer = record['counterparty_signer_name']

# One image placeholder per page, followed by the question.
messages = [{
    "role": "user",
    "content": [
        *({"type": "image"} for _ in images),
        {"type": "text", "text": "Who is the counter party signer?"},
    ],
}]
# Load the processor from the same checkpoint as the model so they cannot drift apart.
processor = Idefics2Processor.from_pretrained(model_id)
text = processor.apply_chat_template(messages, add_generation_prompt=True)
print(text, answer)

display(*images)

In [None]:
# Generate up to 50 new tokens for the prompt built above, without tracking gradients.
with torch.no_grad():
    model_inputs = processor(images=images, text=text, return_tensors="pt").to(device)
    generated_ids = model.generate(**model_inputs, max_new_tokens=50)
    decoded = processor.batch_decode(generated_ids, skip_special_tokens=True)
    generated_text = decoded[0]
    print("Generated text:", generated_text)