In [4]:
import dash
from dash import dcc, html, Input, Output, State, ctx, dash_table
import dash_bootstrap_components as dbc
import base64
import tempfile
import os
import jsonlines
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import re
import json
import hashlib
import nltk
import dash_cytoscape as cyto
from nltk.tokenize import word_tokenize
from nltk.stem import WordNetLemmatizer

# Fetch the NLTK resources named 'punkt' and 'wordnet' at import time; this
# module later calls word_tokenize and WordNetLemmatizer.lemmatize, which
# presumably depend on them (NOTE(review): newer NLTK releases may need
# additional resource names -- confirm against the installed version).
nltk.download('punkt')
nltk.download('wordnet')
# Single shared lemmatizer instance used by clean_triple_component.
lemmatizer = WordNetLemmatizer()

# suppress_callback_exceptions=True is needed because one callback listens to
# 'cytoscape-graph', a component that is only created by update_view and is
# therefore absent from the initial layout.
app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP], suppress_callback_exceptions=True)
app.title = "Prompt-Ontology Triple Extractor"

# ------------------- Normalization -------------------

def clean_triple_component(text):
    """Normalize one triple component (subject, relation or object).

    The text is lower-cased, underscores become spaces, every character
    outside ``a-z``, ``0-9`` and space is dropped, and surrounding whitespace
    is trimmed. The result is tokenized with ``word_tokenize`` and each token
    is lemmatized with the module-level ``lemmatizer``.

    Returns:
        The lemmatized tokens joined by single spaces (``""`` when no tokens
        remain).
    """
    normalized = text.lower().replace("_", " ")
    normalized = re.sub(r"[^a-z0-9 ]", "", normalized).strip()
    lemmas = [lemmatizer.lemmatize(word) for word in word_tokenize(normalized)]
    return " ".join(lemmas)

# ------------------- Model Setup Functions -------------------

def setup_llama_model():
    """Load ``meta-llama/Meta-Llama-3-8B`` and wrap it in a text-generation pipeline.

    The model is loaded with ``device_map="auto"`` and ``torch.float16``
    weights, and ``model.config.use_cache`` is set to ``False`` before the
    pipeline is built.

    Returns:
        A ``(pipeline, tokenizer)`` tuple.
    """
    checkpoint = "meta-llama/Meta-Llama-3-8B"
    llama_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    llama_model = AutoModelForCausalLM.from_pretrained(
        checkpoint,
        device_map="auto",
        torch_dtype=torch.float16,
    )
    llama_model.config.use_cache = False
    text_generator = pipeline(
        "text-generation",
        model=llama_model,
        tokenizer=llama_tokenizer,
        device_map="auto",
    )
    return text_generator, llama_tokenizer

def setup_mistral_model():
    """Load ``mistralai/Mistral-7B-Instruct-v0.3`` and wrap it in a text-generation pipeline.

    The model is loaded with ``device_map="auto"`` and ``torch.float16``
    weights.

    Returns:
        A ``(pipeline, tokenizer)`` tuple.
    """
    checkpoint = "mistralai/Mistral-7B-Instruct-v0.3"
    mistral_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    mistral_model = AutoModelForCausalLM.from_pretrained(
        checkpoint,
        device_map="auto",
        torch_dtype=torch.float16,
    )
    text_generator = pipeline(
        "text-generation",
        model=mistral_model,
        tokenizer=mistral_tokenizer,
        device_map="auto",
    )
    return text_generator, mistral_tokenizer

# ------------------- Output Parsing -------------------

def extract_test_outputs(response):
    """Pull the text following ``Test Output:`` out of each generation result.

    Args:
        response: Iterable of dict-like results; each entry's
            ``'generated_text'`` value (default ``''``) is searched.

    Returns:
        A list with one stripped string per entry that contained the marker.
        Each capture runs from the marker up to the first newline that is
        followed (after optional whitespace) by ``#``, or to the end of the
        text. When no entry matches, ``['Output not found']`` is returned.
    """
    marker = re.compile(r'Test Output:\s*(.*?)(?=\n\s*#|$)', re.DOTALL)
    hits = (marker.search(entry.get('generated_text', '')) for entry in response)
    found = [hit.group(1).strip() for hit in hits if hit]
    return found or ['Output not found']

def parse_model_output(model_output):
    """Parse ``relation(subject, object)`` expressions into triple dicts.

    The stripped text is split on newlines and every match of the
    ``rel(sub, obj)`` pattern on a line becomes one dict with the keys
    ``"sub"``, ``"rel"`` and ``"obj"``, each passed through
    ``clean_triple_component``.

    Returns:
        A list of triple dicts in order of appearance (possibly empty).
    """
    triple_re = re.compile(r'(.+?)\s*\(([^,]+),\s*([^)]+)\)')
    return [
        {
            "sub": clean_triple_component(subject),
            "rel": clean_triple_component(relation),
            "obj": clean_triple_component(target),
        }
        for row in model_output.strip().split('\n')
        for relation, subject, target in triple_re.findall(row)
    ]

# ------------------- Graph Utilities -------------------

def color_for_node(node_id):
    """Return a ``#rrggbb`` color string derived from *node_id*.

    The color is the first six hex digits of the MD5 digest of the encoded
    string, so the same input always yields the same color.
    """
    digest = hashlib.md5(node_id.encode()).hexdigest()
    return f"#{digest[:6]}"

def sanitize_id(text, max_length=20):
    """Turn *text* into an identifier made of word characters only.

    Every run of non-word characters is replaced by a single underscore. When
    the result is longer than *max_length*, it is cut to ``max_length - 9``
    characters and suffixed with ``_`` plus the first eight hex digits of the
    MD5 digest of the original text.

    Args:
        text: Arbitrary label.
        max_length: Length above which the result is shortened.

    Returns:
        The sanitized (and possibly shortened) identifier.
    """
    candidate = re.sub(r'\W+', '_', text)
    if len(candidate) <= max_length:
        return candidate
    digest = hashlib.md5(text.encode()).hexdigest()[:8]
    return f"{candidate[:max_length - 9]}_{digest}"

def triples_to_cytoscape_elements(triples):
    """Convert triple dicts into a list of Cytoscape node and edge elements.

    Args:
        triples: Iterable of dicts with the keys ``'sub'``, ``'rel'`` and
            ``'obj'`` (extra keys are ignored).

    Returns:
        A list of element dicts. Nodes come first (one per distinct sanitized
        subject/object id, colored via ``color_for_node``), followed by edges:

        * When edges exist in both directions between two *different* nodes,
          a single edge with class ``'bidirectional'`` is emitted whose label
          is the sorted union of the relations of both directions.
        * Otherwise one edge per ``(source, target)`` pair is emitted, labeled
          with the sorted relations and carrying ``sub``/``rel``/``obj`` data.

        Triples whose subject or object sanitizes to an empty id are skipped,
        because an empty id cannot identify a graph element.
    """
    elements = []
    known_nodes = set()
    edge_dict = {}
    for triple in triples:
        s, p, o = triple['sub'], triple['rel'], triple['obj']
        s_id = sanitize_id(s)
        o_id = sanitize_id(o)
        if not s_id or not o_id:
            # A component that normalized to "" would yield an element with
            # an empty id; drop the triple from the graph instead.
            continue
        for node_id, text in ((s_id, s), (o_id, o)):
            if node_id not in known_nodes:
                elements.append({'data': {'id': node_id, 'label': text, 'fullLabel': text}, 'style': {'background-color': color_for_node(text)}})
                known_nodes.add(node_id)
        # Group relations by directed node pair so parallel edges collapse.
        edge_dict.setdefault((s_id, o_id), set()).add((p, s, o))

    processed = set()
    for (src, tgt), rel_info_set in edge_dict.items():
        if (src, tgt) in processed:
            # Already covered by the bidirectional edge of the reverse pair.
            continue
        rels = [rel for rel, _, _ in rel_info_set]
        reverse = (tgt, src)
        # A self-loop is its own reverse pair; require distinct endpoints so
        # it is rendered as an ordinary edge rather than a bidirectional one.
        if src != tgt and reverse in edge_dict:
            rels_opposite = [rel for rel, _, _ in edge_dict[reverse]]
            label = ', '.join(sorted(set(rels + rels_opposite)))
            elements.append({'data': {'id': f'{src}_bi_{tgt}', 'source': src, 'target': tgt, 'label': label, 'fullLabel': label}, 'classes': 'bidirectional'})
            processed.add((src, tgt))
            processed.add(reverse)
        else:
            label = ', '.join(sorted(rels))
            # All entries of the set share the same endpoints; any one of
            # them supplies the human-readable subject and object.
            _, sub, obj = next(iter(rel_info_set))
            elements.append({'data': {'id': f'{src}_{tgt}_{label}', 'source': src, 'target': tgt, 'label': label, 'fullLabel': label, 'sub': sub, 'rel': label, 'obj': obj}})
            processed.add((src, tgt))
    return elements

# ------------------- UI Layout -------------------

# Page layout: two upload widgets, the model/view controls with the run
# button, a status line, and the result area. The two dcc.Store components
# hold the data written by run_extraction and read by update_view.
app.layout = dbc.Container([
    html.H2("Triple Extractor using LLaMA-3 / Mistral"),
    # File inputs. The ontology upload is wired into run_extraction as State,
    # but that callback's body does not read it.
    dbc.Row([
        dbc.Col(dcc.Upload(id='upload-prompt', children=html.Div(["Upload Prompt JSONL"]), multiple=False), width=6),
        dbc.Col(dcc.Upload(id='upload-ontology', children=html.Div(["Upload Ontology JSON"]), multiple=False), width=6),
    ]),
    html.Br(),
    # Controls: model choice (defaults to "mistral"), table/graph toggle
    # (defaults to "table"), and the button that triggers run_extraction.
    dbc.Row([
        dbc.Col(dcc.Dropdown(id='model-selector', options=[{"label": "LLaMA 3 (Meta)", "value": "llama"}, {"label": "Mistral 7B", "value": "mistral"}], value="mistral", clearable=False), width=4),
        dbc.Col(dbc.RadioItems(id='view-selector', options=[{"label": "Table View", "value": "table"}, {"label": "Graph View", "value": "graph"}], value="table", inline=True), width=6),
        dbc.Col(dbc.Button("Run Extraction", id="run-btn", color="primary"), width="auto"),
    ]),
    # Status message written by run_extraction.
    html.Div(id="status", style={"marginTop": 20}),
    html.Hr(),
    # Flat list of extracted triples (table data).
    dcc.Store(id='triples-store'),
    # Container that update_view fills with either the table or the graph.
    html.Div(id="result-output"),
    # Details of the last tapped node/edge, written by show_element_description.
    html.Div(id='cytoscape-description', style={'marginTop': '20px'}),
    # Cytoscape elements (graph data).
    dcc.Store(id='cy-elements-store'),
])

# ------------------- Callbacks -------------------

def _read_uploaded_jsonl(contents):
    """Decode a ``dcc.Upload`` contents string and parse it as JSON Lines.

    Args:
        contents: String of the form ``"<header>,<base64 payload>"``.

    Returns:
        A list with one parsed object per line.

    Raises:
        ValueError: If the header/payload separator is missing, the payload is
            not valid base64, or a line is not valid JSON (NOTE(review):
            relies on the jsonlines parse error subclassing ValueError --
            confirm for the installed version).
    """
    # Split only on the first comma so a stray comma cannot break unpacking.
    _header, separator, payload = contents.partition(',')
    if not separator:
        raise ValueError("upload contents are not in '<header>,<base64>' form")
    raw = base64.b64decode(payload)

    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=".jsonl") as handle:
            handle.write(raw)
            tmp_path = handle.name
        # Close the reader deterministically instead of leaking the handle.
        with jsonlines.open(tmp_path) as reader:
            return list(reader)
    finally:
        # delete=False keeps the file after closing, so remove it explicitly.
        if tmp_path is not None and os.path.exists(tmp_path):
            os.remove(tmp_path)


@app.callback(
    Output("status", "children"),
    Output("triples-store", "data"),
    Output("cy-elements-store", "data"),
    Input("run-btn", "n_clicks"),
    State("model-selector", "value"),
    State("upload-prompt", "contents"),
    State("upload-prompt", "filename"),
    State("upload-ontology", "contents"),
    State("upload-ontology", "filename"),
    prevent_initial_call=True
)
def run_extraction(n_clicks, selected_model, prompt_content, prompt_name, ontology_content, ontology_name):
    """Run the selected model over the uploaded prompts and collect triples.

    Only the first 10 prompts are processed. Each prompt is generated twice
    (``num_return_sequences=2``); the parsed triples of both generations are
    merged and de-duplicated per prompt.

    Args:
        n_clicks: Click count of the run button (trigger only).
        selected_model: ``"llama"`` selects LLaMA; any other value selects
            Mistral.
        prompt_content: Upload contents of the prompt JSONL file. Each line is
            expected to be an object with a ``"prompt"`` key and, optionally,
            an ``"id"`` key (the line index is used when ``"id"`` is missing).
        prompt_name, ontology_content, ontology_name: Accepted for the
            callback signature; not used.

    Returns:
        ``(status message, list of triple rows, list of Cytoscape elements)``.
        Each triple row is ``{"ID", "sub", "rel", "obj"}``. On a missing or
        unreadable prompt file the last two values are ``None`` and ``[]``.
    """
    if not prompt_content:
        return "Please upload prompt file.", None, []

    try:
        prompts = _read_uploaded_jsonl(prompt_content)
    except ValueError as exc:
        return f"Could not read prompt file: {exc}", None, []

    # NOTE: the model is loaded on every click; this is slow but keeps at most
    # one model's weights referenced at a time.
    generator, _tokenizer = setup_llama_model() if selected_model == "llama" else setup_mistral_model()

    results = []
    failures = []
    for idx, item in enumerate(prompts[:10]):
        item_id = item.get("id", idx) if isinstance(item, dict) else idx
        try:
            # do_sample=True is required for `temperature` to take effect and
            # for greedy-decoding setups to accept num_return_sequences > 1.
            response = generator(item["prompt"], max_new_tokens=512, num_return_sequences=2, do_sample=True, temperature=0.2)
        except Exception as exc:
            # Best effort: one bad prompt must not abort the batch, but the
            # failure is reported in the status line instead of being hidden.
            failures.append((item_id, exc))
            continue

        all_triples = []
        seen = set()
        for out in extract_test_outputs(response):
            for t in parse_model_output(out):
                key = (t["sub"], t["rel"], t["obj"])
                if key not in seen:
                    seen.add(key)
                    all_triples.append(t)
        results.append({"id": item_id, "triples": all_triples})

    all_cleaned_triples = [
        {"ID": res["id"], **triple}
        for res in results
        for triple in res["triples"]
    ]

    cy_elements = triples_to_cytoscape_elements(all_cleaned_triples)
    status = f"Processed {len(results)} prompts using {selected_model.title()} model."
    if failures:
        first_id, first_exc = failures[0]
        status += f" {len(failures)} prompt(s) failed; first failure (id {first_id}): {first_exc!r}"
    return status, all_cleaned_triples, cy_elements

# The two stores are Inputs (not State) so that the view is rendered as soon
# as run_extraction writes new data; with State only, nothing would appear
# until the user toggled the view selector.
@app.callback(
    Output("result-output", "children"),
    Input("view-selector", "value"),
    Input("triples-store", "data"),
    Input("cy-elements-store", "data"),
    prevent_initial_call=True
)
def update_view(view_mode, triples, elements):
    """Render the extraction result as a table or as a graph.

    Args:
        view_mode: ``"table"`` renders a ``DataTable``; any other value
            renders the Cytoscape graph.
        triples: Rows with the keys ``ID``, ``sub``, ``rel``, ``obj`` (or
            ``None`` before the first run).
        elements: Cytoscape elements (or ``None`` before the first run).

    Returns:
        The component to place into ``result-output``.
    """
    if view_mode == "table":
        return dash_table.DataTable(
            columns=[{"name": c, "id": c} for c in ["ID", "sub", "rel", "obj"]],
            data=triples or [],
            page_size=10,
            style_table={"overflowX": "auto"}
        )

    return cyto.Cytoscape(
        id='cytoscape-graph',
        layout={'name': 'cose'},
        style={'width': '100%', 'height': '600px', 'border': '1px solid black'},
        elements=elements or [],
        stylesheet=[
            {'selector': 'node', 'style': {'content': 'data(label)', 'text-valign': 'center', 'text-halign': 'center', 'width': '50px', 'height': '50px', 'font-size': '14px', 'font-weight': 'bold'}},
            {'selector': 'edge', 'style': {'label': 'data(label)', 'curve-style': 'bezier', 'target-arrow-shape': 'triangle', 'arrow-scale': 1.5, 'line-color': '#A9A9A9', 'target-arrow-color': '#A9A9A9', 'font-size': '11px', 'text-rotation': 'autorotate'}},
            {'selector': '.bidirectional', 'style': {'source-arrow-shape': 'triangle', 'target-arrow-shape': 'triangle', 'line-color': '#0074D9', 'source-arrow-color': '#0074D9', 'target-arrow-color': '#0074D9', 'font-weight': 'bold'}}
        ]
    )

@app.callback(
    Output('cytoscape-description', 'children'),
    Input('cytoscape-graph', 'tapNodeData'),
    Input('cytoscape-graph', 'tapEdgeData')
)
def show_element_description(node_data, edge_data):
    """Describe the graph element that was tapped last.

    Args:
        node_data: ``data`` dict of the last tapped node, if any.
        edge_data: ``data`` dict of the last tapped edge, if any.

    Returns:
        A ``html.Div`` with the node label, or with the edge's subject,
        relation(s) and object (falling back to ``source``/``label``/
        ``target``); a hint string when nothing relevant triggered the
        callback.
    """
    hint = "Click on a node or edge to see details."
    if not ctx.triggered:
        return hint

    source_prop = ctx.triggered[0]["prop_id"]
    node_tapped = source_prop == "cytoscape-graph.tapNodeData"
    edge_tapped = source_prop == "cytoscape-graph.tapEdgeData"

    if node_tapped and node_data:
        node_label = node_data.get('fullLabel', node_data.get('label', ''))
        return html.Div([
            html.B("Node details:"), html.Br(),
            f"Label: {node_label}"
        ])

    if edge_tapped and edge_data:
        subject = edge_data.get('sub', edge_data.get('source', ''))
        relation = edge_data.get('rel', edge_data.get('label', ''))
        target = edge_data.get('obj', edge_data.get('target', ''))
        return html.Div([
            html.B("Edge details:"), html.Br(),
            f"Subject: {subject}", html.Br(),
            f"Relation(s): {relation}", html.Br(),
            f"Object: {target}"
        ])

    return hint

if __name__ == "__main__":
    # `jupyter_mode` is the keyword Dash.run uses to choose how the app is
    # shown from a notebook; a bare `mode=` is not one of its parameters.
    # NOTE(review): host='0.0.0.0' together with debug=True exposes the dev
    # server on all interfaces -- confirm this is intended.
    app.run(jupyter_mode='external', host='0.0.0.0', port=8070, debug=True)


[nltk_data] Downloading package punkt to
[nltk_data]     /upb/users/b/balram/profiles/unix/cs/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package wordnet to
[nltk_data]     /upb/users/b/balram/profiles/unix/cs/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


Loading checkpoint shards: 100%|██████████| 3/3 [00:03<00:00,  1.07s/it]
Device set to use cuda:0
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
