Based on https://python.langchain.com/docs/how_to/structured_output/

In [None]:
import random
import itertools
from pprint import pprint
from pathlib import Path
import os
import sys
import json
import re

from dotenv import load_dotenv
from jinja2 import Environment, FileSystemLoader
from tqdm.notebook import tqdm

# from langchain_ibm import ChatWatsonx
from langchain_anthropic import ChatAnthropic
from langchain_ollama import ChatOllama
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, FewShotChatMessagePromptTemplate
from langchain_core.documents import Document

from pydantic import BaseModel, ValidationError, Field
from typing import Any, Optional, List

# Load variables from a local .env file (python-dotenv) before the environment is read.
load_dotenv()
# Anthropic credential passed to ChatAnthropic further down; indexing os.environ
# raises KeyError when the variable is not set.
api_key = os.environ['ANTHROPIC_API_KEY']
# base_url = os.environ['WATSONX_URL']
# project_id = os.environ['WATSONX_PROJECT_ID']



In [None]:
# Notebook-wide counter (0, 1, 2, ...) shared by every Document created below, so the
# numeric part of an id is unique across both block ("B") and fragment ("F") documents.
id_gen = itertools.count()

def p(d):
    """Return a short label ``"<id> (<start>, <end>)"`` built from *d*'s metadata."""
    meta = d.metadata
    return "{} ({}, {})".format(meta['id'], meta['start'], meta['end'])

def split_text(fn) -> List[Document]:
    """Split the text file *fn* into Documents separated by blank lines.

    A separator is any match of ``\\n\\s*\\n`` (a line break, optional whitespace,
    another line break). Each piece between two separators becomes one Document
    whose metadata holds:

    * ``id``: ``"B<n>"``, with ``n`` taken from the module-level ``id_gen`` counter,
    * ``start`` / ``end``: character offsets of the piece in the file content
      (``end`` is exclusive, i.e. ``content[start:end]`` is the piece),
    * ``filename``: *fn* as given.

    The text after the last separator is always emitted, so the result has at
    least one Document (which may be empty).
    """
    with open(fn, "r") as handle:
        text = handle.read()

    def piece(lo, hi):
        # The slice is taken before the id is drawn from the shared counter.
        return Document(
            page_content=text[lo:hi],
            metadata={
                "id": f"B{next(id_gen)}",
                "start": lo,
                "end": hi,
                "filename": fn,
            },
        )

    docs = []
    cursor = 0  # offset where the current piece begins
    for gap in re.finditer(r'\n\s*\n', text):
        docs.append(piece(cursor, gap.start()))
        cursor = gap.end()
    # Tail after the last separator (the whole file when no separator was found).
    docs.append(piece(cursor, len(text)))
    return docs

# Inputs of the SHA-1 example: the two text files and the JSON file whose
# annotations refer to character ranges inside them.
lhs_file = "../frontend/public/data/SHA-1/selected-text.txt"
rhs_file = "../frontend/public/data/SHA-1/pre-written.txt"
annotations_file = "../frontend/public/data/SHA-1/annotations.json"

# Blank-line-separated blocks ("B" documents) of each file.
lhs_docs = split_text(lhs_file)
rhs_docs = split_text(rhs_file)
# Last expression of the cell: shows the number of blocks found on each side.
len(lhs_docs), len(rhs_docs)

(27, 14)

In [54]:
from collections import defaultdict

# Fragment Documents built from the annotation ranges, one list per side.
lhs, rhs = [], []
# Full text of each file keyed by side ("lhs" / "rhs"); create_document slices into it.
content = {}

# The loaded JSON is indexed with ['mappings'] by the loop later in this cell.
with open(annotations_file, "r") as file:
    annotations = json.load(file)
# Both texts are opened exactly as split_text opens them (text mode, default
# encoding), so character offsets computed on either reading agree.
with open(lhs_file, 'r') as f:
    content["lhs"] = f.read()
with open(rhs_file, 'r') as f:
    content["rhs"] = f.read()

def create_document(d, idx, side, fn):
    """Build the Document for one annotated range.

    Args:
        d: mapping with ``'start'`` and ``'end'`` offsets into ``content[side]``.
        idx: index of the mapping the range belongs to; it becomes the first
            part of the id (``"F<idx>_<n>"``, ``n`` from the shared ``id_gen``).
        side: key into the module-level ``content`` dict.
        fn: value stored under ``filename`` in the metadata.
    """
    source = content[side]
    begin, stop = d['start'], d['end']
    # Slice first, then draw the id, so a failing slice does not consume a counter value.
    fragment = source[begin:stop]
    meta = {
        "id": f"F{idx}_{next(id_gen)}",
        "start": begin,
        "end": stop,
        "filename": fn,
    }
    return Document(page_content=fragment, metadata=meta)

# Turn every annotated range into a fragment Document. For each mapping the
# left-hand ranges are created first, then the right-hand ones; the mapping
# index `i` ends up in the fragment id.
for i, a in enumerate(annotations['mappings']):
    lhs.extend(
        create_document(d, i, "lhs", lhs_file) for d in a.get('lhsRanges', [])
    )
    rhs.extend(
        create_document(d, i, "rhs", rhs_file) for d in a.get('rhsRanges', [])
    )

# Last expression of the cell: number of fragments per side.
len(lhs), len(rhs)

(25, 20)

In [55]:
def is_range(a, b):
    """Classify how the character ranges of documents *a* and *b* relate.

    Both arguments must expose a ``metadata`` mapping with ``start`` and ``end``
    offsets. The ranges are half-open: ``end`` is exclusive, matching the
    ``content[start:end]`` slices the documents are built from.

    Returns:
        1 if A lies inside B (identical ranges included),
        2 if B lies inside A,
        3 if A and B share no character; this includes ranges that merely
          touch, such as (0, 5) and (5, 10),
        4 if A and B overlap without one containing the other.

    Raises:
        AssertionError: if either range has ``start > end``.
    """
    ai, ae = a.metadata['start'], a.metadata['end']
    assert ai <= ae, f"invalid range for A: ({ai}, {ae})"
    bi, be = b.metadata['start'], b.metadata['end']
    assert bi <= be, f"invalid range for B: ({bi}, {be})"
    if bi <= ai and ae <= be:
        return 1 # A inside B
    elif ai <= bi and be <= ae:
        return 2 # B inside A
    elif ae <= bi or be <= ai:
        # `end` is exclusive, so one range ending where the other starts is not an overlap.
        return 3 # A and B disjoint
    else:
        return 4 # A and B intersect

def coverage(bs, fs):
    """Assign each fragment in *fs* to a block in *bs* that contains it.

    Blocks and fragments are both visited in ascending order of their ``end``
    offset. A fragment is attached to the first block for which
    ``is_range(fragment, block) == 1`` and is then no longer considered for
    later blocks.

    Returns:
        A pair ``(blocks, leftovers)``: ``blocks`` maps every block id to the
        list of ids of the fragments attached to it (possibly empty), and
        ``leftovers`` holds the ``p(...)`` label of each fragment that no
        block contains.
    """
    def by_end(doc):
        return doc.metadata['end']

    ordered_blocks = sorted(bs, key=by_end)
    pending = sorted(fs, key=by_end)

    blocks = {blk.metadata['id']: [] for blk in ordered_blocks}
    for blk in ordered_blocks:
        still_pending = []
        for frag in pending:
            if is_range(frag, blk) == 1:
                blocks[blk.metadata['id']].append(frag.metadata['id'])
            else:
                still_pending.append(frag)
        pending = still_pending

    return blocks, [p(frag) for frag in pending]

# Cell output: for each side, the fragments contained in every block, followed by
# the labels of the fragments that fit inside no single block.
coverage(lhs_docs, lhs), coverage(rhs_docs, rhs)

(({'B0': ['F9_66', 'F10_70', 'F7_59'],
   'B1': ['F9_67'],
   'B2': ['F11_73', 'F12_75', 'F13_78', 'F15_84', 'F14_81'],
   'B3': [],
   'B4': ['F10_71'],
   'B5': [],
   'B6': ['F7_60'],
   'B7': [],
   'B8': [],
   'B9': [],
   'B10': [],
   'B11': ['F0_41', 'F1_43', 'F2_46', 'F3_49', 'F4_52'],
   'B12': ['F1_44', 'F2_47', 'F3_50', 'F5_54', 'F6_56'],
   'B13': ['F7_58', 'F9_65', 'F10_69', 'F8_63'],
   'B14': [],
   'B15': [],
   'B16': [],
   'B17': [],
   'B18': [],
   'B19': [],
   'B20': [],
   'B21': [],
   'B22': [],
   'B23': [],
   'B24': [],
   'B25': [],
   'B26': []},
  []),
 ({'B27': [],
   'B28': [],
   'B29': [],
   'B30': ['F7_62'],
   'B31': [],
   'B32': [],
   'B33': ['F11_74',
    'F12_76',
    'F15_85',
    'F13_79',
    'F14_82',
    'F12_77',
    'F13_80',
    'F9_68',
    'F14_83'],
   'B34': ['F10_72'],
   'B35': [],
   'B36': ['F0_42', 'F4_53', 'F8_64', 'F3_51', 'F5_55', 'F7_61'],
   'B37': ['F6_57', 'F1_45', 'F2_48'],
   'B38': [],
   'B39': [],
   'B40': []},

In [None]:
# Model served through Ollama.
# NOTE: this binding is replaced by the ChatAnthropic assignment right below, so
# when the whole cell runs the Ollama model is constructed but never used.
chat = ChatOllama(
    model = "llama3.2:3b",
    temperature=0,
)

# Anthropic model; this is the object `chat` refers to once the cell has run.
# The key comes from the ANTHROPIC_API_KEY environment variable read at the top.
chat = ChatAnthropic(
    model = 'claude-3-7-sonnet-20250219',
    temperature=0,
    max_tokens=3000,
    timeout=None,
    max_retries=2,
    api_key= api_key,
)

In [None]:
# Schema passed to chat.with_structured_output(...): the shape the model's answer
# is parsed into. Documented with comments rather than a docstring because the
# Field descriptions (and presumably a class docstring too; confirm against the
# pydantic docs) become part of the schema shown to the model.
# NOTE(review): the few-shot prompt applies random.choice to .lhsText / .rhsText of
# its examples; on a str that picks a single character. Confirm whether these two
# fields are meant to be lists of strings.
class Annotation(BaseModel):
    # Text taken from the left-hand-side input.
    lhsText     : str = Field(description="the fragments of the LHS text")
    # Text taken from the right-hand-side input.
    rhsText     : str = Field(description="the fragments of the RHS text")
    # Free-form string; the description names the three intended values but nothing enforces them.
    status      : str = Field(description="default (preserve equivalence), warning (ambiguous), error (incorrect)")

# Shape of one few-shot exchange: the human turn carries a TEXT/CODE pair and the
# AI turn answers with the status alone.
example_prompt = ChatPromptTemplate.from_messages([
    ("human", "*** TEXT {text} \n\n*** CODE {code}"),
    ("ai", "{status}"),
])

# NOTE(review): `examples` is not defined in the part of the notebook visible here;
# each item must expose .lhsText, .rhsText and .status. random.choice picks one
# element of a sequence, or a single character if the attribute is a str.
few_shot_prompt = FewShotChatMessagePromptTemplate(
    examples=[
        {
            "text": random.choice(ex.lhsText),
            "code": random.choice(ex.rhsText),
            "status": ex.status,
        }
        for ex in examples
    ],
    example_prompt=example_prompt,
)

# The system instruction lives in a separate file next to the notebook.
with open('ar-prompt.text', 'r') as file:
    instruction = file.read()

# Full prompt: system instruction, then the few-shot exchanges, then the pair to judge.
final_prompt = ChatPromptTemplate.from_messages([
    ("system", instruction),
    few_shot_prompt,
    ("human", "*** TEXT {input} \n\n*** CODE {code}"),
])

# The model is asked to answer with an Annotation; the chain feeds the prompt into it.
structured_chat = chat.with_structured_output(Annotation)
chain = final_prompt | structured_chat

import jsondiff as jd
from jsondiff import diff, JsonDiffer
import statistics

# NOTE: this rebinds `jd`, which the import just above bound to the jsondiff
# module, to a JsonDiffer instance.
jd = JsonDiffer()

# One entry per (lhs, rhs) pair the chain answered successfully.
report = []
for l in lhs[5:10]:
    for r in rhs[5:10]:
        # `l` and `r` are Documents. Formatting a Document into the prompt (or an
        # f-string) would insert its repr, metadata included, so only the raw text
        # is handed to the model and printed.
        lhs_text, rhs_text = l.page_content, r.page_content
        try:
            res = chain.invoke({"input": lhs_text, "code": rhs_text})
        except Exception as e:
            # Best-effort sweep: say which pair failed and carry on with the next one.
            print(f"{l.metadata['id']} x {r.metadata['id']} failed: {e}")
        else:
            report.append(
                {"lhs": l.metadata['id'], "rhs": r.metadata['id'], "result": res}
            )
            print(f"lhs: {lhs_text}")
            print(f"rhs: {rhs_text}")
            print(res)

lhs: \subsection*{5.2.1 SHA-1, SHA-224 and SHA-256}
For SHA-1, SHA-224 and SHA-256, the message and its padding are parsed into \(N\) 512-bit blocks, \(M^{(1)}, M^{(2)}, \ldots, M^{(N)}\). Since the 512 bits of the input block may be expressed as sixteen 32bit words, the first 32 bits of message block \(i\) are denoted \(M_{0}^{(i)}\), the next 32 bits are \(M_{1}^{(i)}\), and so on up to \(M_{15}^{(i)}\).
rhs: -- Left rotate operation using mathlib's rotateLeft
def ROTL (n : Nat) (x : Word) : Word :=
  let nn : UInt32 := n.toUInt32
  ((x <<< nn) ||| (x >>> (32 - nn)))
description='ROTL operation' lhsText=['ROTL'] rhsText=['Left rotate operation'] status='ok'
lhs: \subsection*{5.2.1 SHA-1, SHA-224 and SHA-256}
For SHA-1, SHA-224 and SHA-256, the message and its padding are parsed into \(N\) 512-bit blocks, \(M^{(1)}, M^{(2)}, \ldots, M^{(N)}\). Since the 512 bits of the input block may be expressed as sixteen 32bit words, the first 32 bits of message block \(i\) are denoted \(M_{0}^{(i