Based on https://python.langchain.com/docs/how_to/structured_output/

In [1]:
import random
import itertools
from pprint import pprint
from pathlib import Path
import os
import sys
import json
import re

from dotenv import load_dotenv
from jinja2 import Environment, FileSystemLoader
from tqdm.notebook import tqdm

# from langchain_ibm import ChatWatsonx
from langchain_anthropic import ChatAnthropic
from langchain_ollama import ChatOllama
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, FewShotChatMessagePromptTemplate
from langchain_core.documents import Document

from pydantic import BaseModel, ValidationError, Field
from typing import Any, Optional, List

load_dotenv()
api_key = os.environ['ANTHROPIC_API_KEY']
# base_url = os.environ['WATSONX_URL']
# project_id = os.environ['WATSONX_PROJECT_ID']



In [2]:
id_gen = itertools.count(start=0)

def p(d):
    return f"{d.metadata['id']} {d.metadata.get('annotation', '')} ({d.metadata['start']}, {d.metadata['end']})"

def split_text(fn) -> List[Document]:
    pattern = r'\n\s*\n'
    
    with open(fn, "r") as file:
        content = file.read()

    splits = {}
    previous_end = 0
    for match in re.finditer(pattern, content):
        start, end = match.span()
        d = Document(
            page_content = content[previous_end:start], 
            metadata = dict(id = f"B{next(id_gen)}", start = previous_end, end = start, filename = fn)
        )
        splits[d.metadata['id']] = d
        previous_end = end
    d = Document(
            page_content = content[previous_end:], 
            metadata = dict(id = f"B{next(id_gen)}", start = previous_end, end = len(content), filename = fn)
    )
    splits[d.metadata['id']] = d
    return splits

lhs_file = "../frontend/public/data/SHA-1/selected-text.txt"
rhs_file = "../frontend/public/data/SHA-1/pre-written.txt"
annotations_file = "../frontend/public/data/SHA-1/annotations.json"

lhs_bs = split_text(lhs_file)
rhs_bs = split_text(rhs_file)
len(lhs_bs), len(rhs_bs)

(27, 14)

In [3]:
from collections import defaultdict

lhs_fs, rhs_fs = {}, {}
content = {}

with open(annotations_file, "r") as file:
    annotations = json.load(file)
with open(lhs_file, 'r') as f:
    content["lhs"] = f.read()
with open(rhs_file, 'r') as f:
    content["rhs"] = f.read()

def create_document(d, idx, side, fn):
    return Document(
        page_content= content[side][d['start']:d['end']], 
        metadata = dict(id = f"F{next(id_gen)}", 
         start = d['start'], end = d['end'], annotation = idx, filename = fn)
    )

for i, a in enumerate(annotations['mappings']):
    for d in a.get('lhsRanges', []):
        doc = create_document(d, i, "lhs", lhs_file)
        lhs_fs[doc.metadata['id']] = doc
    for d in a.get('rhsRanges', []):
        doc = create_document(d, i, "rhs", rhs_file)
        rhs_fs[doc.metadata['id']] = doc

len(lhs_fs), len(rhs_fs)

(25, 20)

In [4]:
def is_range(a, b):
    ai, ae = a.metadata['start'], a.metadata['end']
    assert ai <= ae
    bi, be = b.metadata['start'], b.metadata['end'] 
    assert bi <= be
    if bi <= ai and ae <= be:
        return 1 # A inside B
    elif ai <= bi and be <= ae:
        return 2 # B inside A
    elif ae < bi or be < ai:
        return 3 # A and B disjoint
    else:
        return 4 # A and B intersect

def coverage(bs, fs):
    f = lambda d: d.metadata['end']
    list1 = sorted(bs.values(), key = f)
    list2 = sorted(fs.values(), key = f)
    blocks = {i.metadata['id']: [] for i in list1}
    for a in list1:
        for b in list2[:]:
            if is_range(b,a) == 1:
                blocks[a.metadata['id']].append(b.metadata['id'])
                list2.remove(b)
    assert len(list2) == 0
    return blocks

lhs_map = coverage(lhs_bs, lhs_fs)
rhs_map = coverage(rhs_bs, rhs_fs)
lhs_map

def custom_str(self):
    return f"{self.metadata['id']} {self.metadata.get('annotation','NA')} ({self.metadata['start']},{self.metadata['end']})"

Document.__str__ = custom_str

def print_map(m, db, df):
    for k,v in m.items():
        print(db[k], [f'{df[i]}' for i in v])

# print_map(lhs_map, lhs_bs, lhs_fs)

bmaps = []
for bl, vl in lhs_map.items():
    ansl = [lhs_fs[k].metadata['annotation'] for k in vl]
    for br, vr in rhs_map.items():
        ansr = [rhs_fs[k].metadata['annotation'] for k in vr]
        if set(ansl) & set(ansr):
            bmaps.append( (bl, br) )


"""
for a, b in bmaps:
    print(f'---ANN \n{lhs_bs[a].page_content}\n~~~\n{rhs_bs[b].page_content}\n')
"""
bmaps



[('B0', 'B30'),
 ('B0', 'B33'),
 ('B0', 'B34'),
 ('B0', 'B36'),
 ('B1', 'B33'),
 ('B2', 'B33'),
 ('B4', 'B34'),
 ('B6', 'B30'),
 ('B6', 'B36'),
 ('B11', 'B36'),
 ('B11', 'B37'),
 ('B12', 'B36'),
 ('B12', 'B37'),
 ('B13', 'B30'),
 ('B13', 'B33'),
 ('B13', 'B34'),
 ('B13', 'B36')]

In [5]:
chat = ChatOllama(
    model = "llama3.2:3b",
    temperature=0,
)

chat = ChatAnthropic(
    model = 'claude-3-7-sonnet-20250219',
    temperature=0,
    max_tokens=3000,
    timeout=None,
    max_retries=2,
    api_key= api_key,
)

In [6]:
class Annotation(BaseModel):
    lhsText     : str = Field(description="the fragments of the LHS text")
    rhsText     : str = Field(description="the fragments of the RHS text")

example_prompt = ChatPromptTemplate.from_messages(
    [
        ("human", "*** TEXT {text} \n\n*** CODE {code}"),
        ("ai", "{status}"),
    ]
)

few_shot_prompt = FewShotChatMessagePromptTemplate(
    example_prompt=example_prompt,
    examples=[dict(
        text = lhs_bs[m[0]].page_content, 
        code = rhs_bs[m[1]].page_content,
        status = "related"
        ) for m in random.sample(bmaps, 5)]
)

env = Environment(loader=FileSystemLoader('.'))
template = env.get_template('ar-prompt.text')
instruction = template.render(
    lhs = "\n~~~\n".join([b.page_content for b in lhs_bs.values()]), 
    rhs = "\n~~~\n".join([b.page_content for b in rhs_bs.values()])
)


In [None]:
# import jsondiff as jd
# from jsondiff import diff, JsonDiffer
# import statistics
# jd = JsonDiffer()

final_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", instruction),
        few_shot_prompt,
        ("human", "*** TEXT {input} \n*** CODE {code}"),
    ]
)

# structured_chat = chat.with_structured_output(Annotation)
# chain = final_prompt # | structured_chat 
chain = {"input": RunnablePassthrough(), "code" : RunnablePassthrough()} | few_shot_prompt 

for lk, lv in lhs_bs:
    for rk, rv in rhs_bs:
        try:
            res = chain.invoke({"input": lv.page_content, "code": rv.page_content})
        except Exception as e:
            print(e)
        else:
            print(f"lhs: {lk}  rhs: {rk} response: {res}")


NameError: name 'RunnablePassthrough' is not defined