In [1]:

from entity_retrieval import surface_index_memory
from itertools import product
from SPARQLWrapper import SPARQLWrapper, JSON
from tqdm import tqdm
import ujson
import re
import requests
from sklearn.preprocessing import normalize


In [2]:
from nor_to_sexpr import convert_s_expression_to_sparql

In [3]:
import nltk
nltk.download('punkt_tab')

[nltk_data] Downloading package punkt_tab to
[nltk_data]     C:\Users\TOPU\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


True

In [4]:
surface_index = surface_index_memory.EntitySurfaceIndexMemory(
    "vi_entity_list_file_wikidata_complete_all_mention", "vi_entity_surface_map_file_wikidata_complete_all_mention","vi_wiki_complete_all_mention" )

INFO:entity_retrieval.surface_index_memory:Loading entity vocabulary from disk.
INFO:entity_retrieval.surface_index_memory:Loading surfaces from disk.
INFO:entity_retrieval.surface_index_memory:Done initializing surface index.


In [5]:
def is_valid_expression(expr):
    """Ki·ªÉm tra t√≠nh h·ª£p l·ªá c·ªßa bi·ªÉu th·ª©c b·∫±ng c√°ch ƒë·∫øm s·ªë ngo·∫∑c m·ªü v√† ƒë√≥ng."""
    count = 0
    for char in expr:
        if char == '(':
            count += 1
        elif char == ')':
            count -= 1
        if count < 0:
            return False  # G·∫∑p ngo·∫∑c ƒë√≥ng tr∆∞·ªõc ngo·∫∑c m·ªü
    return count == 0

def fix_unbalanced_parentheses(expr):
    """Lo·∫°i b·ªè ngo·∫∑c ƒë√≥ng d∆∞ n·∫øu c√≥."""
    while not is_valid_expression(expr) and expr.endswith(')'):
        expr = expr[:-1]
    return expr

In [6]:
import torch
from transformers import AutoTokenizer, AutoModel
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

# Load model & tokenizer
tokenizer = AutoTokenizer.from_pretrained("VoVanPhuc/sup-SimCSE-VietNamese-phobert-base")
model = AutoModel.from_pretrained("VoVanPhuc/sup-SimCSE-VietNamese-phobert-base")


def get_embedding(text):
    """Tr√≠ch xu·∫•t vector embedding t·ª´ m√¥ h√¨nh"""
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=128)
    with torch.no_grad():
        outputs = model(**inputs)

    # Mean pooling
    embeddings = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
    return embeddings

def find_entity_or_relation(label, label_map, facc1_index, top_k=50, similarity_threshold=0.2):
    """
    T√¨m th·ª±c th·ªÉ ho·∫∑c quan h·ªá t·ª´ gold maps, SimCSE ho·∫∑c FACC1.
    """
    label_lower = label.lower()

    # N·∫øu c√≥ trong label_map, tr·∫£ v·ªÅ ngay
    if label_map and label_lower in label_map:
        return label_map[label_lower]

    # L·∫•y embedding cho label
    label_embedding = get_embedding(label_lower).reshape(1, -1)

    if label_map:
        label_keys = list(label_map.keys())
        label_embeddings = np.array([get_embedding(k) for k in label_keys]).squeeze(1)

        # ƒê·∫£m b·∫£o ƒë√∫ng shape
        if len(label_embeddings.shape) == 1:
            label_embeddings = label_embeddings.reshape(1, -1)

        # T√≠nh cosine similarity
        similarities = cosine_similarity(label_embedding, label_embeddings).flatten()

        # Ch·ªçn th·ª±c th·ªÉ g·∫ßn nh·∫•t
        merged_list = list(zip(label_keys, similarities))
        sorted_list = sorted(merged_list, key=lambda x: x[1], reverse=True)

        if sorted_list and sorted_list[0][1] > similarity_threshold:
            return label_map[sorted_list[0][0]]

    # N·∫øu kh√¥ng t√¨m th·∫•y, th·ª≠ trong KB (FACC1)
    facc1_cand_entities = facc1_index.get_indexrange_entity_el_pro_one_mention(label_lower, top_k=top_k)
    if facc1_cand_entities:
        best_match = max(facc1_cand_entities.items(), key=lambda x: x[1])
        return best_match[0]

    return label  # Tr·∫£ v·ªÅ label n·∫øu kh√¥ng t√¨m th·∫•y

def parse_nsexpr(expr):
    """
    Chuy·ªÉn chu·ªói bi·ªÉu th·ª©c th√†nh c√¢y c·∫•u tr√∫c d·∫°ng nested list.
    H√†m n√†y d√πng duy·ªát k√Ω t·ª±, khi g·∫∑p '(' s·∫Ω t√¨m ph·∫ßn con cho ƒë·∫øn khi kh·ªõp v·ªõi ')',
    v√† gi·ªØ nguy√™n n·ªôi dung trong ngo·∫∑c vu√¥ng.
    """
    tokens = []
    i = 0
    while i < len(expr):
        if expr[i].isspace():
            i += 1
        elif expr[i] == '(':
            # T√¨m ph·∫ßn con c·ªßa bi·ªÉu th·ª©c trong ngo·∫∑c ƒë∆°n
            count = 1
            j = i + 1
            while j < len(expr) and count > 0:
                if expr[j] == '(':
                    count += 1
                elif expr[j] == ')':
                    count -= 1
                j += 1
            # ƒê·ªá quy ph√¢n t√≠ch ph·∫ßn con (lo·∫°i b·ªè ngo·∫∑c bao ngo√†i)
            subtree = parse_nsexpr(expr[i+1:j-1])
            tokens.append(subtree)
            i = j
        elif expr[i] == '[':
            # Gi·ªØ nguy√™n n·ªôi dung trong ngo·∫∑c vu√¥ng
            j = expr.find(']', i)
            if j == -1:
                return ""
                raise ValueError("Kh√¥ng t√¨m th·∫•y d·∫•u ']' k·∫øt th√∫c.")
                
            token = expr[i:j+1].strip()
            tokens.append(token)
            i = j + 1
        else:
            # ƒê·ªçc m·ªôt token cho ƒë·∫øn khi g·∫∑p kho·∫£ng tr·∫Øng ho·∫∑c ngo·∫∑c
            j = i
            while j < len(expr) and (not expr[j].isspace()) and expr[j] not in ['(', ')']:
                j += 1
            tokens.append(expr[i:j])
            i = j
    return tokens

def collect_labels(tree):
    """
    Duy·ªát c√¢y c·∫•u tr√∫c (nested list) ƒë·ªÉ thu th·∫≠p c√°c nh√£n c·ªßa quan h·ªá v√† th·ª±c th·ªÉ.
    Gi·∫£ s·ª≠:
      - Bi·ªÉu th·ª©c JOIN c√≥ d·∫°ng: ["JOIN", relation_part, entity_part]
      - Ph·∫ßn relation_part: n·∫øu l√† list v√† b·∫Øt ƒë·∫ßu b·∫±ng "R", th√¨ ph·∫ßn th·ª© hai ch·ª©a nh√£n quan h·ªá (d·∫°ng "[ label ]"). 
        N·∫øu l√† chu·ªói d·∫°ng "[ label ]" th√¨ ƒë√≥ c≈©ng l√† nh√£n quan h·ªá.
      - Ph·∫ßn entity_part: n·∫øu l√† chu·ªói d·∫°ng "[ label ]" th√¨ ƒë√≥ l√† nh√£n th·ª±c th·ªÉ, n·∫øu l√† list th√¨ x·ª≠ l√Ω ƒë·ªá quy.
      - Bi·ªÉu th·ª©c AND s·∫Ω c√≥ nhi·ªÅu bi·ªÉu th·ª©c con.
    """
    relations = []
    entities = []
    
    if isinstance(tree, list) and tree:
        # N·∫øu token ƒë·∫ßu ti√™n l√† JOIN ho·∫∑c AND
        op = tree[0]
        if isinstance(op, str):
            op_upper = op.upper()
        else:
            op_upper = ""
        
        if op_upper == "JOIN":
            # X·ª≠ l√Ω ph·∫ßn quan h·ªá
            if len(tree) >= 2:
                rel_part = tree[1]
                # N·∫øu l√† list d·∫°ng [ "R", "[ label ]" ]
                if isinstance(rel_part, list) and len(rel_part) >= 2 and isinstance(rel_part[0], str) and rel_part[0].upper() == "R":
                    token = rel_part[1]
                    if isinstance(token, str) and token.startswith('[') and token.endswith(']'):
                        rel_label = token[1:-1].strip()
                        relations.append(rel_label)
                # N·∫øu l√† chu·ªói d·∫°ng "[ label ]"
                elif isinstance(rel_part, str) and rel_part.startswith('[') and rel_part.endswith(']'):
                    rel_label = rel_part[1:-1].strip()
                    relations.append(rel_label)
                else:
                    # N·∫øu kh√¥ng ƒë√∫ng ƒë·ªãnh d·∫°ng, duy·ªát ƒë·ªá quy
                    sub_rel, sub_ent = collect_labels(rel_part)
                    relations.extend(sub_rel)
                    entities.extend(sub_ent)
            # X·ª≠ l√Ω ph·∫ßn th·ª±c th·ªÉ
            if len(tree) >= 3:
                ent_part = tree[2]
                if isinstance(ent_part, list):
                    sub_rel, sub_ent = collect_labels(ent_part)
                    relations.extend(sub_rel)
                    entities.extend(sub_ent)
                elif isinstance(ent_part, str) and ent_part.startswith('[') and ent_part.endswith(']'):
                    ent_label = ent_part[1:-1].strip()
                    entities.append(ent_label)
        elif op_upper == "AND":
            # V·ªõi AND, duy·ªát t·∫•t c·∫£ c√°c ph·∫ßn con
            for sub in tree[1:]:
                sub_rel, sub_ent = collect_labels(sub)
                relations.extend(sub_rel)
                entities.extend(sub_ent)
        else:
            # N·∫øu kh√¥ng ph·∫£i JOIN hay AND, duy·ªát t·∫•t c·∫£ c√°c ph·∫ßn t·ª≠ n·∫øu ch√∫ng l√† list
            for elem in tree:
                if isinstance(elem, list):
                    sub_rel, sub_ent = collect_labels(elem)
                    relations.extend(sub_rel)
                    entities.extend(sub_ent)
    return relations, entities



def extract_entities_and_relations(normed_expr):

    if not normed_expr or len(normed_expr) == 0:  # Ki·ªÉm tra n·∫øu normed_expr r·ªóng
        return [], []
    
    if normed_expr[0] != "(":
        return [], []
    
    tree = parse_nsexpr(normed_expr)
    if tree is None:
        return [], []  # Tr·∫£ v·ªÅ danh s√°ch r·ªóng n·∫øu parse th·∫•t b·∫°i
    
    return collect_labels(tree)



  from .autonotebook import tqdm as notebook_tqdm


In [7]:
def find_entity(label, label_map, facc1_index, top_k=20, similarity_threshold=0.2):
    """
    T√¨m th·ª±c th·ªÉ ho·∫∑c quan h·ªá t·ª´ gold maps, SimCSE ho·∫∑c FACC1.
    """
    label_lower = label.lower()

    # N·∫øu c√≥ trong label_map, tr·∫£ v·ªÅ ngay
    if label_map and label_lower in label_map:
        return label_map[label_lower]

    # L·∫•y embedding cho label
    label_embedding = get_embedding(label_lower).reshape(1, -1)

    if label_map:
        label_keys = list(label_map.keys())
        label_embeddings = np.array([get_embedding(k) for k in label_keys]).squeeze(1)
        label_embedding = normalize(label_embedding, axis=1)
        label_embeddings = normalize(label_embeddings, axis=1)
        # T√≠nh cosine similarity
        similarities = cosine_similarity(label_embedding, label_embeddings).flatten()
        # Ch·ªçn th·ª±c th·ªÉ g·∫ßn nh·∫•t
        merged_list = list(zip(label_keys, similarities))
        sorted_list = sorted(merged_list, key=lambda x: x[1], reverse=True)

        if sorted_list and sorted_list[0][1] > similarity_threshold:
            return label_map[sorted_list[0][0]]

    # N·∫øu kh√¥ng t√¨m th·∫•y, th·ª≠ trong KB (FACC1)
    facc1_cand_entities = facc1_index.get_indexrange_entity_el_pro_one_mention(label_lower, top_k=top_k)
    if facc1_cand_entities:
        temp = []
        for key in list(facc1_cand_entities.keys())[1:]:
            if facc1_cand_entities[key] >= 0.001:
                temp.append(key)
        if len(temp) > 0:
            label = [list(facc1_cand_entities.keys())[0]]+temp
        else:
            label = list(facc1_cand_entities.keys())[0]

    return label  # Tr·∫£ v·ªÅ label n·∫øu kh√¥ng t√¨m th·∫•y

In [8]:
def find_relation(label, label_map, relation_kb_map, facc1_index, top_k=20, similarity_threshold=0.2):
    """
    T√¨m quan h·ªá t·ª´ gold maps, SimCSE ho·∫∑c FACC1.
    """
    label_lower = label.lower()

    # üîπ B∆Ø·ªöC 1: Ki·ªÉm tra n·∫øu ƒë√£ c√≥ s·∫µn trong `label_map`
    if label_map and label_lower in label_map:
        return label_map[label_lower]

    # üîπ B∆Ø·ªöC 2: Ki·ªÉm tra n·∫øu c√≥ s·∫µn trong `relation_kb_map`
    if relation_kb_map and label_lower in relation_kb_map:
        return relation_kb_map[label_lower]  # ‚è© Tr·∫£ v·ªÅ ngay n·∫øu c√≥ s·∫µn

    # üîπ B∆Ø·ªöC 3: T√≠nh to√°n embedding v√† t√¨m quan h·ªá g·∫ßn nh·∫•t b·∫±ng cosine similarity
    label_embedding = get_embedding(label_lower).reshape(1, -1)

    if label_map:
        label_keys = list(label_map.keys())
        label_embeddings = np.array([get_embedding(k) for k in label_keys]).squeeze(1)
        label_embedding = normalize(label_embedding, axis=1)
        label_embeddings = normalize(label_embeddings, axis=1)

        similarities = cosine_similarity(label_embedding, label_embeddings).flatten()
        sorted_list = sorted(zip(label_keys, similarities), key=lambda x: x[1], reverse=True)

        if sorted_list and sorted_list[0][1] > similarity_threshold:
            return label_map[sorted_list[0][0]]

    if relation_kb_map:
        label_keys = list(relation_kb_map.keys())
        label_embeddings = np.array([get_embedding(k) for k in label_keys]).squeeze(1)
        label_embedding = normalize(label_embedding, axis=1)
        label_embeddings = normalize(label_embeddings, axis=1)

        similarities = cosine_similarity(label_embedding, label_embeddings).flatten()
        sorted_list = sorted(zip(label_keys, similarities), key=lambda x: x[1], reverse=True)

        if sorted_list and sorted_list[0][1] > similarity_threshold:
            return relation_kb_map[sorted_list[0][0]]

    return label

In [9]:
class SExpressionParser:
    def __init__(self):
        self.var_counter = 1  # ƒê·∫øm s·ªë bi·∫øn trung gian (?X1, ?X2, ...)

    def get_new_var(self):
        """T·∫°o bi·∫øn trung gian m·ªõi."""
        var_name = f"?X{self.var_counter}"
        self.var_counter += 1
        return var_name

    def parse_s_expr(self, s_expr):
        """Chuy·ªÉn ƒë·ªïi S-Expression th√†nh danh s√°ch l·ªìng nhau."""
        s_expr = re.sub(r'\(', ' ( ', s_expr)
        s_expr = re.sub(r'\)', ' ) ', s_expr)
        tokens = s_expr.split()
        return self.build_tree(tokens)

    def build_tree(self, tokens):
        """Chuy·ªÉn ƒë·ªïi danh s√°ch token th√†nh c√¢y l·ªìng nhau."""
        if not tokens:
            return None
        token = tokens.pop(0)
        if token == "(":
            sub_expr = []
            while tokens[0] != ")":
                sub_expr.append(self.build_tree(tokens))
            tokens.pop(0)  # B·ªè d·∫•u ")"
            return sub_expr
        elif token == ")":
            raise ValueError("Unexpected ')'")
        else:
            return token

    def process_join(self, expr, target_var):
        """
        X·ª≠ l√Ω JOIN, t·∫°o triple SPARQL.
        """
        triples = []
        if not isinstance(expr, list):
            return expr, triples

        if expr[0] == "AND":
            # X·ª≠ l√Ω t·ª´ng JOIN trong AND ri√™ng l·∫ª
            for sub_expr in expr[1:]:
                _, sub_triples = self.process_join(sub_expr, target_var)
                triples.extend(sub_triples)
            return target_var, triples

        if expr[0] == "JOIN":
            right_expr = expr[2]
            right_triples = []
            if isinstance(right_expr, list) and right_expr[0] == "JOIN":
                right_var, right_triples = self.process_join(right_expr, self.get_new_var())
            else: 
                right_var = right_expr
            
            # X·ª≠ l√Ω nh√°nh tr√°i
            left_expr = expr[1]
            if isinstance(left_expr, list) and left_expr[0] == "R":
                rel = left_expr[1]
                right = right_var
                left = target_var
                if right[0] != '?':
                    right = "wd:" + right
                if left[0] !='?':
                    left = "wd:" + left   
                triples.append([right, f"wdt:{rel}", left])
            else:
                right = right_var
                left = target_var
                if right[0] != '?':
                    right = "wd:" + right
                if left[0] !='?':
                    left = "wd:" + left   
                triples.append([left, f"wdt:{left_expr}", right])

            # Th√™m c√°c triples t·ª´ nh√°nh ph·∫£i tr∆∞·ªõc khi th√™m triple ch√≠nh
            triples = right_triples + triples
            return target_var, triples

        return expr, triples

    def s_expr_to_sparql(self, s_expr):
        """Chuy·ªÉn ƒë·ªïi t·ª´ S-Expression sang SPARQL."""
        if s_expr.count("(") != s_expr.count(")"):
            return None
        if s_expr.count("[") != s_expr.count("]"):
            return None
        parsed_expr = self.parse_s_expr(s_expr)
        target_var = "?answer"
        final_var, triples = self.process_join(parsed_expr, target_var)

        sparql_body = "\n  ".join([" ".join(t) + " ." for t in triples])
        sparql_query = f"""PREFIX wd: <http://www.wikidata.org/entity/> 
PREFIX wdt: <http://www.wikidata.org/prop/direct/> 
SELECT DISTINCT {target_var} WHERE {{ 
  {sparql_body}
}}"""
        return sparql_query

import requests    
WIKIDATA_SPARQL_ENDPOINT = "https://query.wikidata.org/sparql"
WIKIDATA_API_ENDPOINT = "https://www.wikidata.org/w/api.php"

def execute_query_with_odbc(sparql_query):
    """Truy v·∫•n Wikidata v√† tr·∫£ v·ªÅ danh s√°ch c√¢u tr·∫£ l·ªùi (bao g·ªìm t·∫•t c·∫£ bi·∫øn)"""
    headers = {"User-Agent": "Mozilla/5.0", "Accept": "application/sparql-results+json"}
    response = requests.get(WIKIDATA_SPARQL_ENDPOINT, params={"query": sparql_query, "format": "json"}, headers=headers)

    if response.status_code == 200:
        results = response.json().get("results", {}).get("bindings", [])
        answers = []

        for result in results:
            for var in result:  # Duy·ªát qua t·∫•t c·∫£ c√°c bi·∫øn tr·∫£ v·ªÅ
                value = result[var]["value"]
                answers.append(value)  # Ch·∫•p nh·∫≠n t·∫•t c·∫£ gi√° tr·ªã, kh√¥ng ch·ªâ th·ª±c th·ªÉ Wikidata

        return answers  # Tr·∫£ v·ªÅ to√†n b·ªô danh s√°ch k·∫øt qu·∫£

    return []


In [10]:
def convert_normed_to_s_expression(normed_expr, gold_relation_map, gold_entity_map, relation_KB_map, facc1_index):
    """
    Chuy·ªÉn ƒë·ªïi t·ª´ normed_sexpression sang s_expression.
    Sau khi tr√≠ch xu·∫•t c√°c nh√£n quan h·ªá v√† th·ª±c th·ªÉ t·ª´ normed_expr,
    ta l·∫•y danh s√°ch c√°c m√£ ·ª©ng vi√™n cho m·ªói nh√£n v√† t·∫°o ho√°n v·ªã gi·ªØa c√°c c·∫∑p ·ª©ng vi√™n ƒë√≥.
    K·∫øt qu·∫£ tr·∫£ v·ªÅ l√† m·ªôt danh s√°ch c√°c s_expression kh·∫£ dƒ©.
    """
    normed_expr = fix_unbalanced_parentheses(normed_expr)
    # Tr√≠ch xu·∫•t c√°c nh√£n quan h·ªá v√† th·ª±c th·ªÉ t·ª´ bi·ªÉu th·ª©c
    relations, entities = extract_entities_and_relations(normed_expr)
    
    # T·∫°o mapping t·ª´ token xu·∫•t hi·ªán trong bi·ªÉu th·ª©c sang danh s√°ch c√°c ·ª©ng vi√™n m√£.
    # V√≠ d·ª•: token_str = "[ author ]"
    candidate_map = {}
    for rel in relations:
        token = f'[ {rel} ]'
        candidate = find_relation(rel, gold_relation_map, relation_KB_map, facc1_index)
        # N·∫øu candidate kh√¥ng ph·∫£i danh s√°ch, chuy·ªÉn n√≥ th√†nh danh s√°ch ƒë·ªÉ t·∫°o ho√°n v·ªã
        if not isinstance(candidate, list):
            candidate = [candidate]
        candidate_map[token] = candidate
        
    for ent in entities:
        token = f'[ {ent} ]'
        candidate = find_entity(ent, gold_entity_map , facc1_index)
        if not isinstance(candidate, list):
            candidate = [candidate]
        candidate_map[token] = candidate
    
    # N·∫øu kh√¥ng c√≥ token n√†o c·∫ßn thay th·∫ø, tr·∫£ v·ªÅ bi·ªÉu th·ª©c g·ªëc
    if not candidate_map:
        return [normed_expr]
    
    # L·∫•y danh s√°ch c√°c token v√† danh s√°ch c√°c danh s√°ch ·ª©ng vi√™n t∆∞∆°ng ·ª©ng
    tokens = list(candidate_map.keys())
    candidate_lists = [candidate_map[token] for token in tokens]
    
    # T·∫°o t·∫•t c·∫£ c√°c ho√°n v·ªã ·ª©ng vi√™n (Cartesian product)
    all_combinations = list(product(*candidate_lists))
    
    s_expressions = []
    for comb in all_combinations:
        temp_expr = normed_expr
        # V·ªõi m·ªói token, thay th·∫ø b·∫±ng ·ª©ng vi√™n t∆∞∆°ng ·ª©ng theo ho√°n v·ªã
        for token, replacement in zip(tokens, comb):
            temp_expr = temp_expr.replace(token, replacement)
        s_expressions.append(temp_expr)
    
    return s_expressions

In [11]:

def load_jsonl(file_path):
    """ƒê·ªçc file JSONL v√† tr·∫£ v·ªÅ danh s√°ch c√°c object"""


    with open(file_path, 'r', encoding='utf-8') as f:
        content = f.read().replace('(EXPECTED RESULT)', 'null').replace('(QUESTION)', 'null')

    try:
        data = ujson.loads(content)
        print("JSON loaded successfully!")
    except Exception as e:
        print(f"Error: {e}")
    return data


In [12]:
def calculate_prf1(gold_answers, pred_answers):
    """T√≠nh Precision, Recall, F1-score"""
    if len(gold_answers) == 0:
        if len(pred_answers) == 0:
            return [1.0, 1.0, 1.0]  # ƒê√∫ng khi kh√¥ng c√≥ c√¢u tr·∫£ l·ªùi
        else:
            return [0.0, 1.0, 0.0]
    elif len(pred_answers) == 0:
        return [0.0, 0.0, 0.0]
    tp = 1e-40  # numerical trick

    tp = tp + len(set(gold_answers) & set(pred_answers))

    fp = len(set(pred_answers) - set(gold_answers))
    fn = len(set(gold_answers) - set(pred_answers))
    precision = tp / (tp + fp) 
    recall = tp / (tp + fn) 
    f1 = (2 * precision * recall) / (precision + recall) 

    return [precision, recall, f1]


In [13]:
# ƒê·ªçc d·ªØ li·ªáu
predictions = load_jsonl("LLMs/beam_prediction/generated_predictions_beam_Q7b.json")
gold_data = load_jsonl("Data/LC-QuAD2.0/label_map/LC-QuAD2.0_test.json")
gold_data = gold_data[:1000]

JSON loaded successfully!
JSON loaded successfully!


### Load danh s√°ch quan h·ªá trong KB

In [14]:
file_path = "property_list_file_wikidata_complete_all_mention"

relation_KB_map = {}

# ƒê·ªçc file v√† ƒë·∫£o ng∆∞·ª£c quan h·ªá
with open(file_path, "r", encoding="utf-8") as f:
    for line in f:
        parts = line.strip().split("\t")  # T√°ch theo tab
        if len(parts) == 2:
            relation_KB_map[parts[1]] = parts[0]

In [15]:
ex_cnt = 0
top_hit = 0
failed_preds = []
final_executable_cnt = 0

In [16]:
import json
import time
from tqdm import tqdm

BATCH_SIZE = 300  # S·ªë d√≤ng m·ªói batch
LOG_FILE = "progress_log.txt"

results = []
simcse_model = model  # M√¥ h√¨nh SimCSE c·ªßa b·∫°n
facc1_index = surface_index
parser = SExpressionParser()
start_time = time.perf_counter()  # B·∫Øt ƒë·∫ßu ƒëo th·ªùi gian

# Ghi log
with open(LOG_FILE, "w", encoding="utf-8") as log_file:
    log_file.write("B·∫Øt ƒë·∫ßu qu√° tr√¨nh ƒë√°nh gi√°\n")

for i, (pred, gold) in enumerate(tqdm(zip(predictions, gold_data), total=len(predictions), desc="ƒêang ƒë√°nh gi√°")):
    try:
        gold_entity_map = {v.lower(): k for k, v in gold['gold_entity_map'].items()}
        gold_relation_map = {v.lower(): k for k, v in gold['gold_relation_map'].items()}
        gold_answers = gold.get("answer", [])
        executable_index = None
        best_f1, best_precision, best_recall = 0, 0, 0
        kq = []
        lag_result = False
        denormed_pred = []
        for rank, query in enumerate(pred['predicted_query']):
            if lag_result:
                break

            set_query = convert_normed_to_s_expression(query, gold_relation_map, gold_entity_map, relation_KB_map , facc1_index)
            for q in set_query:
                query_result = []
                if not q:
                    continue
                if rank == 0 and q.lower() ==gold['s_expr'].lower():
                    ex_cnt +=1
                sparql = convert_s_expression_to_sparql(q)
                if sparql == "UNKNOWN":
                    continue
                denormed_pred.append(sparql)
                query_result = execute_query_with_odbc(sparql)
                if query_result:
                    if rank == 0:
                        top_hit += 1
                    executable_index = rank
                    precision, recall, f1 = calculate_prf1(gold_answers, query_result)
                    if f1 > best_f1:
                        kq = query_result
                        best_f1, best_precision, best_recall = f1, precision, recall
                    if precision == 1:
                        lag_result = True
                        break
        if executable_index is not None:
            final_executable_cnt+=1
        else:
            failed_preds.append({'qid':gold["question_id"], 
                'gt_sexpr': gold['s_expr'], 
                'gt_normed_sexpr': pred['gen_label'],
                'pred': pred, 
                'denormed_pred':denormed_pred})    
        results.append({
            "qid": gold["question_id"],
            "answer": gold_answers,
            "result": kq,
            "nor_s_expr":gold["nor_s_expr"],
            "precision": best_precision,
            "recall": best_recall,
            "f1": best_f1
        })

        # Ghi log ti·∫øn tr√¨nh x·ª≠ l√Ω
        with open(LOG_FILE, "a", encoding="utf-8") as log_file:
            log_file.write(f"ƒê√£ x·ª≠ l√Ω xong d√≤ng {i + 1}/{len(predictions)}\n")

        # Khi ƒë·ªß 100 k·∫øt qu·∫£, l∆∞u v√†o file v√† reset bi·∫øn `results`
        if (i + 1) % BATCH_SIZE == 0 or (i + 1) == len(predictions):
            batch_id = (i + 1) // BATCH_SIZE
            filename = f"LLMs/eval_result/evaluation_vinallamaQ7b_part_{batch_id}.json"
            with open(filename, "w", encoding="utf-8") as f:
                json.dump(results, f, indent=4)
            print(f"‚úÖ ƒê√£ l∆∞u {len(results)} d√≤ng v√†o {filename}")
            results = []  # Reset danh s√°ch k·∫øt qu·∫£

    except Exception as e:
        with open(LOG_FILE, "a", encoding="utf-8") as log_file:
            log_file.write(f"L·ªói t·∫°i d√≤ng {i + 1}: {str(e)}\n")
        print(f"‚ùå L·ªói t·∫°i d√≤ng {i + 1}: {e}")

# K·∫øt th√∫c ƒëo th·ªùi gian
end_time = time.perf_counter()
total_time = end_time - start_time
print(f"üéØ Qu√° tr√¨nh ƒë√°nh gi√° ho√†n t·∫•t trong {total_time:.2f} gi√¢y!")

with open(LOG_FILE, "a", encoding="utf-8") as log_file:
    log_file.write(f"Qu√° tr√¨nh ƒë√°nh gi√° ho√†n t·∫•t trong {total_time:.2f} gi√¢y!\n")


ƒêang ƒë√°nh gi√°:   0%|          | 0/1000 [05:12<?, ?it/s]


KeyboardInterrupt: 

In [None]:
print('STR Match', ex_cnt/ len(predictions))
print('TOP 1 Executable', top_hit/ len(predictions))
print('Final Executable', final_executable_cnt/ len(predictions))

STR Match 0.34690943938667945
TOP 1 Executable 0.45759463344513657
Final Executable 0.6411116435074269


In [None]:
STR_Match = ex_cnt/ len(predictions)
TOP1_Executable = top_hit/ len(predictions)
Final_Executable = final_executable_cnt/ len(predictions)

In [None]:
import glob

# Danh s√°ch c√°c file k·∫øt qu·∫£ ƒë√°nh gi√°
file_list = glob.glob("LLMs/eval_result/evaluation_vinallamaQ7b_part_*.json")

# Bi·∫øn ƒë·ªÉ t·ªïng h·ª£p k·∫øt qu·∫£
total_precision = 0
total_recall = 0
total_f1 = 0
total_samples = 0

# ƒê·ªçc t·ª´ng file v√† t·ªïng h·ª£p d·ªØ li·ªáu
for file in file_list:
    with open(file, "r", encoding="utf-8") as f:
        data = json.load(f)
        for entry in data:
            total_precision += entry.get("precision", 0)
            total_recall += entry.get("recall", 0)
            total_f1 += entry.get("f1", 0)
            total_samples += 1

# Tr√°nh chia cho 0
if total_samples > 0:
    avg_precision = total_precision / total_samples
    avg_recall = total_recall / total_samples
    avg_f1 = total_f1 / total_samples
else:
    avg_precision, avg_recall, avg_f1, hits_at_1 = 0, 0, 0, 0

# Hi·ªÉn th·ªã k·∫øt qu·∫£
print(f"üìä Precision: {avg_precision:.5f}")
print(f"üìä Recall: {avg_recall:.5f}")
print(f"üìä F1-score: {avg_f1:.5f}")


üìä Precision: 0.51899
üìä Recall: 0.53568
üìä F1-score: 0.51993


In [None]:
# T·∫°o d·ªØ li·ªáu k·∫øt qu·∫£ ƒë√°nh gi√°
eval_results = {
    "precision": avg_precision,
    "recall": avg_recall,
    "f1": avg_f1,
    "STR Match": STR_Match,
    "Hit@1": TOP1_Executable,
    "Final_Executable": Final_Executable, 
}

# L∆∞u v√†o file JSON
output_file = "LLMs/eval_result/Final_evaluation_Q7b.json"
with open(output_file, "w", encoding="utf-8") as f:
    json.dump(eval_results, f, indent=2, ensure_ascii=False)

print(f"‚úÖ K·∫øt qu·∫£ ƒë√£ ƒë∆∞·ª£c l∆∞u v√†o {output_file}")

‚úÖ K·∫øt qu·∫£ ƒë√£ ƒë∆∞·ª£c l∆∞u v√†o LLMs/eval_result/Final_evaluation_27b.json


In [None]:

output_file = "LLMs/eval_result/Failed_result_Q7b.json"
with open(output_file, "w", encoding="utf-8") as f:
    json.dump(failed_preds, f, indent=2, ensure_ascii=False)

print(f"‚úÖ K·∫øt qu·∫£ ƒë√£ ƒë∆∞·ª£c l∆∞u v√†o {output_file}")

‚úÖ K·∫øt qu·∫£ ƒë√£ ƒë∆∞·ª£c l∆∞u v√†o LLMs/eval_result/Failed_result_27b.json
