In [None]:
# Install the packages this notebook needs: unsloth, sentence-transformers and numpy.
# NOTE(review): versions are not pinned, so a re-run may pull different releases.
!pip install unsloth
!pip install sentence-transformers numpy

In [None]:
# Download the three input files with gdown (each argument is a file id).
# Article database (presumably the file referenced by JSON_DB_FILE - TODO confirm).
!gdown 1vqmL82yiPhgAOCF43gSBJCZ2nlrZRZrL
# Top-10 retrieved articles (presumably the file referenced by TOP10_ARTICLES_FILE - TODO confirm).
!gdown 1h7FTRHfl2KJm8k8mfMsCFSN6nTdo-H1W 
# Base captions (presumably the file referenced by BASE_CAPTIONS_FILE - TODO confirm).
!gdown 1GqxM2uf5y_bMrCzIMMMAMyenRHPA6Q6p

In [None]:
import unsloth
import transformers
import numpy as np
from transformers import TextStreamer
import json
import pandas as pd
import torch
import logging
import os
from tqdm import tqdm

In [None]:

# --- Input files (relative to the notebook's working directory) ---
TOP10_ARTICLES_FILE = "top10_articles.csv"
JSON_DB_FILE = "database_preprocessed.json"
BASE_CAPTIONS_FILE = "base_captions_private.csv"


# --- Checkpointing: CHECKPOINT_FILE is the JSON file read by load_checkpoint() ---
CHECKPOINT_DIR = r"checkpoint"
CHECKPOINT_FILE = os.path.join(CHECKPOINT_DIR, "checkpoint.json")


# --- Final outputs: one JSON and one CSV path inside OUTPUT_DIR ---
OUTPUT_DIR = "ReZeroSlavery"
FINAL_OUTPUT_JSON = os.path.join(OUTPUT_DIR, "ReZeroSlavery.json")
FINAL_OUTPUT_CSV = os.path.join(OUTPUT_DIR, "ReZeroSlavery.csv")

# Path to the model checkpoint to load; adjust to your own model path.
MODEL_PATH = "checkpoint-200" #Adjust your model path here

In [None]:
# Module-level handle to the sentence-embedding model. It starts as None and is
# replaced below by a SentenceTransformer built from MODEL_NAME_SEMANTIC.
# create_advanced_prompt() checks this global for truthiness before using semantic search.
sentence_model_global = None
MODEL_NAME_SEMANTIC = 'all-MiniLM-L12-v2'
from sentence_transformers import SentenceTransformer
sentence_model_global = SentenceTransformer(MODEL_NAME_SEMANTIC)

In [None]:

# Configure the root logger: timestamp | padded level name | message, at INFO level.
logging.basicConfig(
    format="%(asctime)s | %(levelname)-8s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO
)

In [None]:
import re
def _identify_topic(text: str, titles: list) -> str:
    all_text = ' '.join(titles).lower() + ' ' + text.lower()
    topic_keywords = {
        'technology': ['technology', 'tech', 'software', 'AI', 'robot', 'digital', 'computer', 'innovation', 'platform', 'data'],
        'business': ['business', 'company', 'market', 'economy', 'trade', 'finance', 'investment', 'ceo', 'gdp', 'stock', 'enterprise'],
        'politics': ['election', 'president', 'government', 'policy', 'political', 'minister', 'congress', 'parliament', 'senate', 'legislation', 'bill', 'campaign'],
        'sports': ['game', 'player', 'team', 'match', 'championship', 'athlete', 'sport', 'win', 'tournament', 'olympics', 'score'],
        'health': ['health', 'medical', 'doctor', 'patient', 'disease', 'treatment', 'hospital', 'vaccine', 'pandemic', 'healthcare', 'medicine'],
        'environment': ['climate', 'environment', 'pollution', 'renewable', 'energy', 'sustainable', 'carbon', 'green', 'emissions', 'ecology'],
        'entertainment': ['movie', 'film', 'actor', 'music', 'artist', 'show', 'entertainment', 'celebrity', 'concert', 'awards'],
        'science': ['research', 'study', 'scientist', 'discovery', 'experiment', 'science', 'data', 'analysis', 'journal', 'university'],
        'education': ['education', 'school', 'university', 'college', 'student', 'teacher', 'learning', 'curriculum'],
        'social issues': ['social', 'community', 'human rights', 'inequality', 'poverty', 'justice', 'protest'],
        'world affairs': ['international', 'global', 'world', 'geopolitics', 'diplomacy', 'conflict', 'united nations'],
        'food': ['food', 'restaurant', 'chef', 'meal', 'cuisine', 'cooking', 'recipe', 'dining', 'agriculture'],
        'travel': ['travel', 'tourism', 'destination', 'flight', 'hotel', 'vacation', 'journey', 'trip', 'airport']
    }
    topic_scores = {}
    for topic, keywords in topic_keywords.items():
        score = sum(1 for keyword in keywords if keyword in all_text)
        if score > 0:
            topic_scores[topic] = score
    if topic_scores:
        return max(topic_scores, key=topic_scores.get)
    return 'general'

def _extract_organizations(text: str) -> list:
    organizations = []
    patterns = [
        r'\b[A-Z]{2,6}\b', 
        r'\b[A-Z][a-zA-Z]*(?:\s+(?:and|of|the|for)\s+)?[A-Z][a-zA-Z]*(?:\s+(?:Inc|Corp|Ltd|LLC|Co|Group|Holdings|Foundation|Association|Organization|Agency|Department|University|Institute|College|School|Council|Committee|Party|Union|Bank|Studio|Network))\b',
        r'\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*\s+(?:Inc\.|Corp\.|Ltd\.|LLC|Co\.|Group|Foundation|Association|Organization|Agency|Department|University|Institute|College|School|Council|Committee|Party|Union|Bank|Studio|Network)\b',
    ]
    
    known_orgs = [
        'Google', 'Microsoft', 'Apple', 'Amazon', 'Meta', 'Facebook', 'Twitter', 'Netflix', 'Tesla',
        'United Nations', 'World Health Organization', 'European Union', 'NATO', 'NASA',
        'CNN', 'BBC', 'Reuters', 'Associated Press', 'New York Times', 'The Guardian',
    ]

    for pattern in patterns:
        matches = re.findall(pattern, text)
        organizations.extend(matches)

    for org in known_orgs:
        if re.search(r'\b' + re.escape(org) + r'\b', text, re.IGNORECASE):
            organizations.append(org)

    
    processed_orgs = []
    common_words_in_org_names = {'The', 'A', 'An', 'Of', 'And', 'For'} 
    for org in organizations:
        org_stripped = org.strip()
        if len(org_stripped) <= 1 and org_stripped.isupper():
            continue
        if org_stripped.isupper() and len(org_stripped) > 6: 
             if org_stripped not in known_orgs: 
                continue
        if org_stripped in common_words_in_org_names:
            continue
        processed_orgs.append(org_stripped)

    final_orgs = list(set(processed_orgs))
    final_orgs.sort(key=lambda x: (len(x.split()), x.isupper()), reverse=True)
    return final_orgs[:8]


def _extract_people(text: str) -> list:
    people = []
    patterns = [
        r'\b(?:Mr\.|Mrs\.|Ms\.|Miss|Dr\.|Prof\.|President|CEO|Minister|Director|Ambassador|General|Captain|Chef|Senator|Governor|Mayor|Councillor|Judge|Justice)\s+([A-Z][a-z]+(?:\s+[A-Z][a-z]+){0,2})\b',
        
        r'\b([A-Z][a-z]+(?:\s+[A-Z][a-z\'\-]+){1,3})\b' 
    ]

    for pattern in patterns:
        matches = re.findall(pattern, text)
        if isinstance(matches, list) and matches and isinstance(matches[0], tuple): 
             people.extend([m[0] for m in matches if m[0]]) 
        else:
             people.extend(matches)

    
    non_name_keywords = {
        'January', 'February', 'March', 'April', 'May', 'June', 'July', 'August', 'September', 'October', 'November', 'December',
        'Monday', 'Tuesday', 'Wednesday', 'Thursday', 'Friday', 'Saturday', 'Sunday',
        'Today', 'Yesterday', 'Tomorrow', 'Week', 'Month', 'Year',
        'Street', 'Road', 'Avenue', 'City', 'State', 'Country', 'County', 'Park', 'Building', 'Center', 'Plaza', 'Square',
        'The', 'This', 'That', 'These', 'Those', 'And', 'But', 'For', 'With', 'From', 'About', 'Under', 'Over',
        'Is', 'Are', 'Was', 'Were', 'Has', 'Have', 'Had', 'Says', 'Said', 'Told',
        'North', 'South', 'East', 'West', 
        'Company', 'Corporation', 'Incorporated', 'Limited', 'Organization', 'Department', 'University', 'Institute', 'College', 'School' 
    }

    processed_people = []
    for p_match in people:
        p = p_match.strip()
        words = p.split()
        if len(words) >= 2 and all(word[0].isupper() for word in words) and not all(word.isupper() for word in words) and not any(word in non_name_keywords for word in words) and len(p)>3 :
            processed_people.append(p)
        elif len(words) == 1 and p[0].isupper() and p not in non_name_keywords and len(p)>3 and not p.isupper(): 
            processed_people.append(p)


    
    final_people = list(set(processed_people))
    final_people.sort(key=len, reverse=True)
    return final_people[:8]

def _extract_locations(text: str) -> list:
    """Collect likely place names from ``text``.

    Two sources are combined:

    * a fixed gazetteer of countries, cities and regions, matched
      case-insensitively as whole phrases. These hits are trusted and are
      not run through the person/organisation filter, which would otherwise
      discard two-word places such as "New York" as if they were people;
    * a capitalised-word heuristic, whose hits are dropped when they
      coincide with a detected person, organisation, month, weekday or
      honorific.

    Args:
        text: Text to scan.

    Returns:
        Up to 8 distinct locations, longest first; ties keep first-seen
        order, so the result is the same on every run.
    """
    predefined_locations = [
        'Vietnam', 'United States', 'China', 'India', 'Japan', 'Germany', 'United Kingdom', 'France', 'Canada', 'Australia', 'Russia', 'Brazil', 'South Korea', 'Italy', 'Spain',
        # The comma after 'Washington D.C.' matters: without it the literal
        # fuses with 'San Francisco' and neither place can be found.
        'New York', 'Los Angeles', 'Chicago', 'London', 'Paris', 'Berlin', 'Tokyo', 'Beijing', 'Shanghai', 'Seoul', 'Moscow', 'Singapore', 'Sydney', 'Toronto', 'Rome', 'Madrid', 'Washington D.C.',
        'San Francisco', 'Silicon Valley'  # regions
    ]
    # Lookarounds instead of \b: a trailing \b cannot match after the final
    # '.' of "Washington D.C." when a space or the end of the text follows.
    known_hits = [
        loc for loc in predefined_locations
        if re.search(r'(?<!\w)' + re.escape(loc) + r'(?!\w)', text, re.IGNORECASE)
    ]

    candidates = []
    patterns = [
        r'\b([A-Z][a-zA-Z\']+)(?:\s+(?:of|de|the|la)\s+)?(?:[A-Z][a-zA-Z\']+){0,3}(?:,\s*[A-Z][a-zA-Z\.\s]+)?\b'
    ]
    for pattern in patterns:
        # The pattern has a single capturing group, so only the first
        # capitalised word of each match is returned by findall.
        matches = re.findall(pattern, text)
        candidates.extend(match.strip() for match in matches if len(match.strip()) > 2)

    non_location_keywords = _extract_people(text) + _extract_organizations(text) + [
        'January', 'February', 'March', 'April', 'May', 'June', 'July', 'August', 'September', 'October', 'November', 'December',
        'Monday', 'Tuesday', 'Wednesday', 'Thursday', 'Friday', 'Saturday', 'Sunday', 'Mr', 'Ms', 'Dr'
    ]
    non_location_keywords_lower = {k.lower() for k in non_location_keywords}

    processed_locations = list(known_hits)
    for loc in candidates:
        loc_stripped = loc.strip().rstrip(',.')
        # Drop bare numbers and all-caps tokens, except a few country/union codes.
        if len(loc_stripped.split()) == 1 and (loc_stripped.isdigit() or (loc_stripped.isupper() and loc_stripped not in ['US', 'UK', 'EU'])):
            continue
        if loc_stripped.lower() not in non_location_keywords_lower and len(loc_stripped) > 2:
            # Skip short phrases that merely start with "the ".
            if not (loc_stripped.lower().startswith("the ") and len(loc_stripped.split()) < 3):
                processed_locations.append(loc_stripped)

    # Order-preserving de-duplication keeps the stable sort below deterministic.
    final_locations = list(dict.fromkeys(processed_locations))
    final_locations.sort(key=len, reverse=True)
    return final_locations[:8]


def _extract_events(text: str) -> list:
    events = []
    
    patterns = [
        r'\b(?:the\s+)?([A-Z][a-zA-Z0-9\s\'\-]+(?:Conference|Summit|Forum|Meeting|Festival|Games|Olympics|Championship|Cup|Awards|Exhibition|Show|Ceremony|Campaign|Initiative|Project|Program|Operation|War|Battle|Treaty|Accord|Act|Bill|Law|Debate|Election|Crisis|Pandemic|Outbreak|Attack|Incident|Disaster))\b',
        r'\b([A-Z][a-zA-Z]+\s+(?:World Cup|Olympic Games|Grand Prix|Open|Summit|Conference|Festival))\b', 
        r'\b(?:G7 Summit|G20 Summit|COP\d+\sConference)\b', 
        r'\b\d{4}\s+(?:Summer|Winter)\s+Olympics\b', 
        
    ]
    
    known_events = ['World War I', 'World War II', 'Vietnam War', 'Cold War', 'September 11 attacks', 'COVID-19 Pandemic']

    for pattern in patterns:
        matches = re.findall(pattern, text)
        
        events.extend([m if isinstance(m, str) else m[0] for m in matches])

    for event in known_events:
        if re.search(r'\b' + re.escape(event) + r'\b', text, re.IGNORECASE):
            events.append(event)

    processed_events = [event.strip().rstrip(',.') for event in events if len(event.strip()) > 4]
    final_events = list(set(processed_events))
    final_events.sort(key=len, reverse=True)
    return final_events[:5]

def _extract_dates(text: str, provided_date: str = None) -> list:
    dates = []
    if provided_date:
        dates.append(provided_date.strip())

    patterns = [
        
        r'\b(?:Jan(?:uary)?|Feb(?:ruary)?|Mar(?:ch)?|Apr(?:il)?|May|Jun(?:e)?|Jul(?:y)?|Aug(?:ust)?|Sep(?:tember)?|Oct(?:ober)?|Nov(?:ember)?|Dec(?:ember)?)\s+\d{1,2}(?:st|nd|rd|th)?,\s+\d{4}\b',
        r'\b\d{1,2}\s+(?:Jan(?:uary)?|Feb(?:ruary)?|Mar(?:ch)?|Apr(?:il)?|May|Jun(?:e)?|Jul(?:y)?|Aug(?:ust)?|Sep(?:tember)?|Oct(?:ober)?|Nov(?:ember)?|Dec(?:ember)?)\s+\d{4}\b',
        
        r'\b\d{1,2}[/-]\d{1,2}[/-]\d{2,4}\b',
        r'\b\d{4}[/-]\d{1,2}[/-]\d{1,2}\b',
        
        r'\b(?:Jan(?:uary)?|Feb(?:ruary)?|Mar(?:ch)?|Apr(?:il)?|May|Jun(?:e)?|Jul(?:y)?|Aug(?:ust)?|Sep(?:tember)?|Oct(?:ober)?|Nov(?:ember)?|Dec(?:ember)?)\s+\d{4}\b',
        
        r'\b(?:Monday|Tuesday|Wednesday|Thursday|Friday|Saturday|Sunday)(?:,\s*(?:Jan(?:uary)?|Feb(?:ruary)?|Mar(?:ch)?|Apr(?:il)?|May|Jun(?:e)?|Jul(?:y)?|Aug(?:ust)?|Sep(?:tember)?|Oct(?:ober)?|Nov(?:ember)?|Dec(?:ember)?)\s+\d{1,2}(?:st|nd|rd|th)?,\s+\d{4})?\b',
        
        r'\b(?:yesterday|today|tomorrow|last\s+week|next\s+week|last\s+month|next\s+month|this\s+year|last\s+year|next\s+year)\b',
        
        r'\b(?:in|during|on|by|since|until|from|the\s+year\s+of)\s+(\d{4})\b',
        r'\b(\d{4})\b' 
    ]
    for pattern in patterns:
        matches = re.findall(pattern, text, re.IGNORECASE)
        if pattern.endswith(r'\b(\d{4})\b') or pattern.endswith(r'\s+(\d{4})\b') :
            dates.extend(m for m in matches if 1900 <= int(m) <= 2050) 
        else:
            dates.extend(m.strip() for m in matches)


    
    final_dates = []
    current_year = 2025 
    for d_match in dates:
        d = d_match.strip().rstrip(',.')
        if d.isdigit() and len(d) == 4:
            year = int(d)
            if 1900 <= year <= current_year + 5: 
                final_dates.append(d)
        elif len(d) > 3: 
            final_dates.append(d)

    final_dates = list(set(final_dates))
    final_dates.sort(key=lambda x: (len(x), x), reverse=True)
    return final_dates[:5]

def _extract_key_terms(text: str, title:str) -> list:
    combined_text = (title.lower() + " ") * 3 + text.lower() 

    
    text_no_urls = re.sub(r'http\S+|www.\S+|\S+@\S+', '', combined_text)

    
    words = re.findall(r'\b[a-zA-Z0-9][a-zA-Z0-9\-\']*[a-zA-Z0-9]\b', text_no_urls)

    
    stop_words = {
        'the', 'a', 'an', 'is', 'are', 'was', 'were', 'be', 'being', 'been', 'this', 'that', 'these', 'those',
        'and', 'or', 'but', 'if', 'of', 'at', 'by', 'for', 'with', 'about', 'to', 'from', 'in', 'out', 'on',
        'it', 'its', 'he', 'she', 'they', 'them', 'his', 'her', 'their', 'you', 'your', 'we', 'our',
        'i', 'me', 'my', 'mine', 'us', 'ours', 'myself', 'yourself', 'himself', 'herself', 'itself', 'ourselves', 'yourselves', 'themselves',
        'what', 'which', 'who', 'whom', 'whose', 'why', 'how', 'when', 'where',
        'all', 'any', 'both', 'each', 'few', 'more', 'most', 'other', 'some', 'such', 'no', 'nor', 'not', 'only', 'own', 'same', 'so', 'than', 'too', 'very',
        'can', 'will', 'just', 'don', 'should', 'now', 'do', 'does', 'did', 'doing', 'said', 'says', 'also', 'get', 'go', 'make', 'know', 'see', 'use', 'find', 'tell', 'ask', 'work', 'seem', 'feel', 'try', 'leave', 'call',
        'one', 'two', 'three', 'january', 'february', 'march', 'april', 'may', 'june', 'july', 'august', 'september', 'october', 'november', 'december',
        'mr', 'mrs', 'ms', 'dr', 'prof', 'inc', 'ltd', 'corp', 
        
        'news', 'report', 'story', 'article', 'image', 'photo', 'picture', 'video', 'caption', 'description',
        'people', 'person', 'man', 'woman', 'child', 'children', 'group', 'team',
        'world', 'country', 'city', 'government', 'company', 'organization', 'event', 'system', 'part', 'number', 'way', 'thing', 'day', 'year', 'time', 'today', 'content', 'information', 'context', 'detail', 'example'
    }

    
    proper_nouns_phrases = re.findall(r'\b[A-Z][a-zA-Z0-9\-\']*(?:\s+[A-Z][a-zA-Z0-9\-\']*){0,3}\b', text) 
    filtered_proper_nouns = []
    for phrase in proper_nouns_phrases:
        p_words = phrase.split()
        if not all(word.lower() in stop_words for word in p_words) and \
           not (len(p_words) == 1 and p_words[0].lower() in stop_words) and \
           len(phrase.strip()) > 2 :
            filtered_proper_nouns.append(phrase.strip())


    filtered_words = [word for word in words if word not in stop_words and len(word) > 2 and not word.isdigit()]

    term_freq = {}
    
    for term in filtered_proper_nouns + title.lower().split():
        if term.lower() not in stop_words and len(term)>2:
            term_freq[term.lower()] = term_freq.get(term.lower(), 0) + 2 

    for word in filtered_words:
        term_freq[word] = term_freq.get(word, 0) + 1

    
    sorted_terms = sorted(term_freq.items(), key=lambda x: (x[1], len(x[0].split()), len(x[0])), reverse=True)

    final_terms = []
    seen_lower = set()
    for term, freq in sorted_terms:
        if term not in seen_lower: 
            original_case_term = term
            for pn in filtered_proper_nouns:
                if pn.lower() == term:
                    original_case_term = pn
                    break
            final_terms.append(original_case_term)
            seen_lower.add(term)
        if len(final_terms) >= 10:
            break

    return final_terms


def _extract_numbers(text: str) -> list:
    patterns = [
        
        r'(?:\$|€|£|¥|USD|EUR|GBP|JPY|VND)\s*\d+(?:[.,]\d{3})*(?:[.,]\d+)?(?:\s*(?:million|billion|trillion|thousand|K|M|B|T))?\b',
        r'\b\d+(?:[.,]\d{3})*(?:[.,]\d+)?\s*(?:dollars?|euros?|pounds?|yen|đồng|USD|EUR|GBP|JPY|VND)(?:\s*(?:million|billion|trillion|thousand|K|M|B|T))?\b',
        
        r'\b\d+(?:[.,]\d+)?\s*%(?:\s*points)?\b',
        r'\b\d+(?:[.,]\d+)?\s*(?:percent|per\s+cent|percentage\s+points?)\b',
        
        r'\b\d+(?:[.,]\d{3})*(?:[.,]\d+)?\s*(?:people|users|viewers|votes|cases|deaths|infections|jobs|companies|countries|cities|members|students|teachers|schools|hospitals|doctors|patients|items|products|services|cars|houses|buildings|acres|hectares|tons|kg|grams|liters|gallons|km|kilometers|meters|miles|feet|gb|mb|tb|hz|watts|volts|degrees|°C|°F|points|barrels|shares|pages|chapters|articles|sections|votes)\b',
        
        r'\b\d+(?:[.,]\d+)?\s*(?:to|-|–)\s*\d+(?:[.,]\d+)?\b',
        
        r'\b\d{1,3}(?:[.,]\d{3})*(?:[.,]\d+)?\b',
        
        r'\b(?:age|aged)\s+\d+\b',
        r'\b\d+\s*years?\s*old\b',
    ]
    numbers = []
    for pattern in patterns:
        matches = re.findall(pattern, text, re.IGNORECASE)
        numbers.extend(matches)


    
    processed_numbers = []
    for num_match in numbers:
        num_str = num_match.strip()
        
        try:
            
            if re.fullmatch(r'\d+', num_str) and int(num_str) < 10: 
                if not any(unit in text[text.find(num_str):text.find(num_str)+len(num_str)+10].lower() for unit in ['million', 'billion', 'thousand', '%', 'percent', 'degree']): # Kiểm tra ngữ cảnh gần
                    continue
        except ValueError:
            pass 
        processed_numbers.append(num_str)

    final_numbers = list(set(processed_numbers))
    final_numbers.sort(key=lambda x: any(kw in x.lower() for kw in ['million', 'billion', 'trillion', 'percent', '%', '$', '€', '£']), reverse=True) # Ưu tiên số có đơn vị lớn/tiền tệ/phần trăm
    return final_numbers[:10]


def extract_key_info(article) -> dict:
    """Build the key-information record for a single article.

    The title and the complete (untruncated) content are concatenated and
    passed to the ``_identify_topic`` / ``_extract_*`` helpers.

    Args:
        article: Mapping with optional 'title', 'content' and 'url' keys,
            or a falsy value.

    Returns:
        Dict with keys 'titles', 'sources', 'topic', 'organizations',
        'people', 'locations', 'events', 'numbers', 'dates', 'key_terms'
        and 'context' (the full content). A falsy ``article`` yields empty
        lists, topic 'general' and an empty context.
    """
    if not article:
        empty_info = {
            'titles': [],
            'sources': [],
            'topic': 'general',
            'organizations': [],
            'people': [],
            'locations': [],
            'events': [],
            'numbers': [],
            'dates': [],
            'key_terms': [],
            'context': '',
        }
        return empty_info

    title = article.get('title', '')
    full_content = article.get('content', '')

    # Derive a publisher label from the URL; only the first matching domain counts.
    sources = []
    url = article.get('url')
    if url:
        for domain, publisher in (('cnn.com', 'CNN'), ('guardian.com', 'The Guardian')):
            if domain in url:
                sources.append(publisher)
                break

    # Title plus full article content is what every extractor analyses.
    all_text = title + ' ' + full_content

    return {
        'titles': [title] if title else [],
        'sources': sources,
        'topic': _identify_topic(all_text, [title]),
        'organizations': _extract_organizations(all_text),
        'people': _extract_people(all_text),
        'locations': _extract_locations(all_text),
        'events': _extract_events(all_text),
        'numbers': _extract_numbers(all_text),
        'dates': _extract_dates(all_text),
        'key_terms': _extract_key_terms(all_text, title),
        'context': full_content,  # complete article content, not truncated
    }

In [None]:
# SEMANTIC SEARCH
def _semantic_article_extraction(full_content: str, base_caption: str, key_info: dict) -> str:

    try:
        
        sentences = re.split(r'(?<=[.!?])\s+', full_content.strip())
        sentences = [s.strip() for s in sentences if len(s.strip()) > 15] 

        if not sentences or len(sentences) < 5 : 
            return full_content[:5000] + "..." if len(full_content) > 5000 else full_content

        
        chunk_size = 3
        overlap = 1
        chunks = []
        
        for i in range(0, len(sentences) - chunk_size + 1, chunk_size - overlap):
            chunk = ' '.join(sentences[i : i + chunk_size])
            if chunk.strip():
                chunks.append(chunk)
                # chunk_indices.append(list(range(i, i + chunk_size)))

        if not chunks:
            return full_content[:5000] + "..." if len(full_content) > 5000 else full_content

        
        search_queries = [base_caption]
        if key_info.get('titles'): search_queries.append(f"Title context: {key_info['titles'][0]}")
        if key_info.get('key_terms'): search_queries.append(f"Key terms: {', '.join(key_info['key_terms'][:3])}")
        if key_info.get('people'): search_queries.append(f"People involved: {', '.join(key_info['people'][:2])}")
        if key_info.get('organizations'): search_queries.append(f"Organizations: {', '.join(key_info['organizations'][:2])}")
        if key_info.get('locations'): search_queries.append(f"Locations: {', '.join(key_info['locations'][:2])}")


        chunk_embeddings = sentence_model_global.encode(chunks, show_progress_bar=False, batch_size=128)
        query_embeddings = sentence_model_global.encode(search_queries, show_progress_bar=False, batch_size=128)

        
        all_sim_scores = []
        for query_emb in query_embeddings:
            sim_scores = np.dot(chunk_embeddings, query_emb) / (np.linalg.norm(chunk_embeddings, axis=1) * np.linalg.norm(query_emb))
            all_sim_scores.append(sim_scores)

        combined_scores = np.max(np.array(all_sim_scores), axis=0)


        num_top_chunks = min(max(5, int(len(chunks) * 0.3)), 10) 

        
        top_indices_by_score = np.argsort(combined_scores)[-num_top_chunks:]



        selected_chunks_with_scores = []
        for idx in top_indices_by_score:
            selected_chunks_with_scores.append((idx, chunks[idx], combined_scores[idx])) # (original_chunk_idx, chunk_text, score)

        
        selected_chunks_with_scores.sort(key=lambda x: x[0])

        
        relevant_text_parts = [chunk_data[1] for chunk_data in selected_chunks_with_scores]

        
        final_relevant_sentences = set()
        
        for part in relevant_text_parts:
            for s_in_chunk in re.split(r'(?<=[.!?])\s+', part.strip()):
                 if s_in_chunk.strip():
                    final_relevant_sentences.add(s_in_chunk.strip())


        
        first_sentences = [s.strip() for s in sentences[:min(3, len(sentences))] if s.strip()]
        last_sentences = [s.strip() for s in sentences[max(0, len(sentences)-2):] if s.strip()]

        
        combined_set = set(first_sentences) | final_relevant_sentences | set(last_sentences)


        
        sentence_to_original_index = {sentence: i for i, sentence in enumerate(sentences)}

        sorted_combined_sentences = sorted(list(combined_set), key=lambda s: sentence_to_original_index.get(s, float('inf')))


        extracted_text = ' '.join(sorted_combined_sentences)

        
        max_len = 6000 
        if len(extracted_text) > max_len:
            extracted_text = extracted_text[:max_len]
            last_sentence_end = extracted_text.rfind('.')
            if last_sentence_end > 0:
                extracted_text = extracted_text[:last_sentence_end+1]
            else: 
                extracted_text += "..."

        return extracted_text

    except Exception as e:
        print(f"Error when semantic searching: {e}")
        import traceback
        traceback.print_exc()
        return full_content[:5000] + "..." if len(full_content) > 5000 else full_content

In [None]:
USE_SEMANTIC_SEARCH = True
def create_advanced_prompt(base_caption: str, key_info: dict) -> str:
    """Build the instruction prompt for caption generation.

    The prompt contains the visual caption, the article topic, short lists
    of entities from ``key_info`` and a condensed version of the article
    body, followed by fixed writing instructions and an example.

    Article condensation:
        * with ``USE_SEMANTIC_SEARCH`` on, a sentence model loaded and an
          article longer than 3000 characters: semantic extraction;
        * otherwise up to 3000 characters: full text; up to 8000: first 10
          sentences, every 3rd sentence of the middle, last 4 sentences;
          longer: first 5000 characters plus every 5th following sentence
          (at most 20).
        The result is capped at 6000 characters.

    Args:
        base_caption: Short visual description of the image.
        key_info: Dict as produced by ``extract_key_info``. Every key is
            optional; an empty or missing 'titles' list is allowed.

    Returns:
        The complete prompt string.
    """
    topic = key_info.get('topic', 'general')
    source_str = ""  # placeholder in the header line; currently always empty

    organizations = key_info.get('organizations', [])[:3]
    people = key_info.get('people', [])[:3]
    locations = key_info.get('locations', [])[:3]
    events = key_info.get('events', [])[:2]
    numbers = key_info.get('numbers', [])[:3]
    dates = key_info.get('dates', [])[:3]

    full_context = key_info.get('context', '')

    # 'titles' may be present but empty (extract_key_info returns [] for an
    # untitled article), so indexing the raw value could raise IndexError.
    titles = key_info.get('titles') or ['']
    article_title = titles[0]

    context_elements = []
    if article_title:
        context_elements.append(f"MAIN STORY: {article_title}")
    if organizations:
        context_elements.append(f"KEY ORGANIZATIONS: {', '.join(organizations)}")
    if people:
        context_elements.append(f"PEOPLE INVOLVED: {', '.join(people)}")
    if locations:
        context_elements.append(f"LOCATIONS: {', '.join(locations)}")
    if events:
        context_elements.append(f"EVENTS: {', '.join(events)}")
    if numbers:
        context_elements.append(f"KEY FIGURES: {', '.join(numbers)}")
    if dates:
        context_elements.append(f"TIMELINE: {', '.join(dates)}")

    article_summary = ""

    if full_context:
        if USE_SEMANTIC_SEARCH and sentence_model_global and len(full_context) > 3000:
            article_summary = _semantic_article_extraction(full_context, base_caption, key_info)
        else:
            article_length = len(full_context)
            if article_length <= 3000:
                # Short article: use the full content.
                article_summary = full_context
            elif article_length <= 8000:
                # Medium article: sample sentences with higher density.
                sentences = full_context.split('. ')
                total_sentences = len(sentences)

                # First 10 sentences (usually the most important).
                key_sentences = list(sentences[:10])

                # Every 3rd sentence from the middle section.
                middle_start = 8
                middle_end = total_sentences - 4
                for i in range(middle_start, middle_end, 3):
                    key_sentences.append(sentences[i])

                # Last 4 sentences (conclusions, outcomes).
                key_sentences.extend(sentences[-4:])

                # Remove duplicates and blanks while preserving order.
                seen = set()
                unique_sentences = []
                for sentence in key_sentences:
                    if sentence.strip() and sentence not in seen:
                        seen.add(sentence)
                        unique_sentences.append(sentence)

                article_summary = '. '.join(unique_sentences)
            else:
                # Long article: keep the head verbatim ...
                article_summary = full_context[:5000]

                # ... then add every 5th sentence of the rest, at most 20.
                remaining_sentences = full_context[5000:].split('. ')
                additional_sentences = remaining_sentences[::5][:20]

                if additional_sentences:
                    article_summary += ". " + '. '.join(additional_sentences)

        max_length = 6000  # generous cap so long articles keep most of their content
        if len(article_summary) > max_length:
            # Cut at the last complete sentence if that keeps at least 85% of the limit.
            truncated = article_summary[:max_length]
            last_period = truncated.rfind('. ')
            if last_period > max_length * 0.85:
                article_summary = truncated[:last_period + 1]
            else:
                article_summary = truncated + "..."

        context_elements.append(f"📄 COMPREHENSIVE ARTICLE CONTENT: {article_summary}")

    context_str = '\n'.join(context_elements) if context_elements else "General news context available."

    # Instruction template, taken verbatim from the reference semantic-search code.
    prompt = f"""You are a news caption expert. Your task is to write a news caption that PRIORITIZES the article content and news significance over visual description.

BRIEF VISUAL: {base_caption}

PRIORITY NEWS CONTEXT{source_str}:
Topic: {topic.title()}
{context_str}
CRITICAL INSTRUCTIONS:
1. The NEWS CONTEXT is MORE IMPORTANT than visual details
2. Start with "The image shows" but immediately connect to the news story
3. Use 70% article information + 30% visual description
4. Focus on WHO, WHAT, WHY, WHEN, WHERE from the article
5. Mention specific names, organizations, events from the article
6. Explain the news significance and broader implications
7. Only describe visual elements that support the news story
8. Write 300-350 words prioritizing factual news content

GOOD EXAMPLE:
"The image shows the scene from a significant political development as President Biden announces new healthcare legislation during a White House ceremony. This landmark bill, supported by Democratic leadership including Speaker Pelosi, aims to expand Medicare coverage to millions of Americans. The legislation comes after months of negotiations with pharmaceutical companies and represents a major victory for the administration's domestic agenda. The outdoor ceremony, attended by healthcare advocates and congressional leaders, marks the culmination of a campaign promise made during the 2020 election. The new law is expected to reduce prescription drug costs by 15% and provide coverage for dental and vision services, affecting approximately 12 million seniors nationwide."
YOUR CAPTION:
"""
    return prompt

In [None]:
# TẠO CAPTION
def generate_caption(model, tokenizer, prompt):
    """Generate a caption from the model, following the tutorial.ipynb ChatML pattern.

    The prompt is wrapped as a single user message, rendered with the
    tokenizer's chat template and passed to ``model.generate``. Only the
    newly generated tokens are decoded, so the prompt and the role header
    never reach the caption.

    Args:
        model: Object exposing ``generate(input_ids=..., ...)``.
        tokenizer: Object exposing ``apply_chat_template``, ``decode`` and
            ``eos_token_id``.
        prompt: Full instruction prompt.

    Returns:
        The cleaned caption, or an empty string if anything fails (the
        error is logged).
    """
    try:
        messages = [{"role": "user", "content": prompt}]

        inputs = tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,  # required so the model answers as the assistant
            return_tensors="pt",
        ).to("cuda" if torch.cuda.is_available() else "cpu")
        prompt_length = inputs.shape[-1]

        outputs = model.generate(
            input_ids=inputs,
            max_new_tokens=350,
            use_cache=True,
            temperature=1.0,
            min_p=0.1,
            pad_token_id=tokenizer.eos_token_id,
        )

        # Decode only the continuation. Decoding the whole sequence and then
        # deleting "assistant " by string replacement corrupted captions that
        # legitimately contain that word.
        response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

        # Release GPU memory as early as possible.
        del inputs, outputs
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        generated_caption = response

        # Defensive: strip chat markers in case they were not removed as special tokens.
        for marker in ("<|im_start|>user", "<|im_start|>assistant", "<|im_end|>"):
            generated_caption = generated_caption.replace(marker, "")

        # Defensive: drop an echoed copy of the prompt if the model repeats it.
        if prompt and prompt in generated_caption:
            generated_caption = generated_caption.split(prompt, 1)[-1]

        generated_caption = generated_caption.strip()

        if generated_caption.startswith("### Response:"):
            generated_caption = generated_caption[13:].strip()

        # Remove a short trailing fragment left when generation was cut off mid-sentence.
        sentences = generated_caption.split('.')
        if len(sentences) > 1 and len(sentences[-1].strip()) < 10:
            generated_caption = '.'.join(sentences[:-1]) + '.'

        return generated_caption

    except Exception as e:
        logging.error(f"Error when generating caption: {e}")
        return ""

In [None]:

def create_backup_filename():
    """Return a backup path inside CHECKPOINT_DIR named after the current local time.

    The file name has the form ``backup_YYYYmmdd_HHMMSS.json``.
    """
    from datetime import datetime
    # A datetime format spec inside an f-string is rendered like strftime.
    stamp = f"{datetime.now():%Y%m%d_%H%M%S}"
    backup_name = "backup_" + stamp + ".json"
    return os.path.join(CHECKPOINT_DIR, backup_name)

def load_checkpoint():
    """Load previously generated results from CHECKPOINT_FILE.

    Returns:
        dict: the checkpoint entries whose value is a dict containing a
        'generated_caption' key. Entries that fail this check are dropped, so
        callers can trust every returned value. An empty dict is returned when
        the file is missing, empty, malformed, or contains no valid entry.

    Side effects:
        A zero-byte or JSON-corrupted checkpoint file is deleted.
    """
    if not os.path.exists(CHECKPOINT_FILE):
        logging.info("Checkpoint is not found")
        return {}

    try:
        if os.path.getsize(CHECKPOINT_FILE) == 0:
            logging.warning("Checkpoint is empty")
            os.remove(CHECKPOINT_FILE)
            return {}

        with open(CHECKPOINT_FILE, 'r', encoding='utf-8') as f:
            content = f.read().strip()
        if not content:
            # File held only whitespace.
            logging.warning("Checkpoint does not have content")
            return {}

        checkpoint = json.loads(content)

        if not isinstance(checkpoint, dict):
            logging.warning("Checkpoint is not follow the right format")
            return {}

        # Keep only well-formed entries: returning the raw dict would let
        # malformed values flow into the final outputs and mark their queries
        # as already done.
        valid_entries = {
            query_id: result
            for query_id, result in checkpoint.items()
            if isinstance(result, dict) and 'generated_caption' in result
        }

        if not valid_entries:
            logging.warning("Checkpoint is empty")
            return {}

        logging.info(f"Load checkpoint: {len(valid_entries)}/{len(checkpoint)} queries")
        return valid_entries

    except json.JSONDecodeError as e:
        logging.warning(f"Checkpoint file is corrupted (JSON error: {e})")
        try:
            os.remove(CHECKPOINT_FILE)
        except OSError as remove_err:
            # Best effort: report the failure instead of silently ignoring it.
            logging.warning(f"Cannot remove corrupted checkpoint: {remove_err}")
        return {}
    except Exception as e:
        logging.warning(f"Error when loading checkpoint: {e}")
        return {}

def save_checkpoint(results, is_final=False):
    """Write `results` to CHECKPOINT_FILE, backing up the previous file first.

    Args:
        results: JSON-serializable object to store (the driver passes a dict
            of per-query result rows).
        is_final: when True, the existing checkpoint is not copied to a
            timestamped backup before being overwritten.

    The data is written to a temporary file and then swapped into place with
    os.replace, so a crash mid-write cannot leave a truncated checkpoint.
    Any failure is logged, not raised, so a failed save never stops the caller.
    """
    temp_file = None
    try:
        os.makedirs(CHECKPOINT_DIR, exist_ok=True)

        if os.path.exists(CHECKPOINT_FILE) and not is_final:
            backup_file = create_backup_filename()
            try:
                import shutil
                shutil.copy2(CHECKPOINT_FILE, backup_file)
                logging.info(f"backup created: {backup_file}")
            except Exception as e:
                # A missing backup must not prevent the checkpoint itself.
                logging.warning(f"cannot create backup: {e}")

        temp_file = CHECKPOINT_FILE + ".tmp"
        with open(temp_file, 'w', encoding='utf-8') as f:
            json.dump(results, f, ensure_ascii=False, indent=2)

        # The temp file sits next to the checkpoint, so this is a same-directory
        # rename that overwrites the destination in one step.
        os.replace(temp_file, CHECKPOINT_FILE)

    except Exception as e:
        logging.error(f"Error when saving checkpoint: {e}")
        # Do not leave a partially written temp file behind.
        if temp_file is not None and os.path.exists(temp_file):
            try:
                os.remove(temp_file)
            except OSError as cleanup_err:
                logging.warning(f"Cannot remove temporary checkpoint file: {cleanup_err}")

def get_progress_info(results, total_queries):
    """Summarize progress as ``(completed, remaining, percent)``.

    Args:
        results: sized container of finished items; only its length is used.
        total_queries: total number of queries to process.

    Returns:
        tuple: number completed, number remaining (total minus completed), and
        the completion percentage (0 when `total_queries` is not positive).
    """
    done = len(results)
    if total_queries > 0:
        percent = done / total_queries * 100
    else:
        percent = 0
    return done, total_queries - done, percent

def cleanup_old_backups(keep_last=5):
    """Delete all but the `keep_last` most recent ``backup_*.json`` files in CHECKPOINT_DIR.

    Recency is decided by ``os.path.getctime``. Every failure is logged as a
    warning and never raised.
    """
    try:
        import glob
        candidates = sorted(
            glob.glob(os.path.join(CHECKPOINT_DIR, "backup_*.json")),
            key=os.path.getctime,
            reverse=True,
        )

        # Everything after the first `keep_last` entries is an old backup to delete.
        for stale_path in candidates[keep_last:]:
            try:
                os.remove(stale_path)
                logging.info(f"Deleted old backup: {os.path.basename(stale_path)}")
            except Exception as err:
                logging.warning(f"Cannot delete {stale_path}: {err}")

    except Exception as err:
        logging.warning(f"Error when cleaning backup: {err}")

def show_resume_info(results, total_queries):
    """Log a short resume summary when `results` already holds entries.

    Nothing is logged when `results` is empty.
    """
    if not results:
        return

    done, left, percent = get_progress_info(results, total_queries)
    summary_lines = (
        "RESUME INFO:",
        f"Complete: {done}/{total_queries} queries ({percent:.1f}%)",
        f"Remain: {left} queries",
        "-" * 60,
    )
    for line in summary_lines:
        logging.info(line)

In [None]:
def load_base_captions():
    """Read BASE_CAPTIONS_FILE into a DataFrame.

    Returns:
        pandas.DataFrame: the parsed CSV, or an empty DataFrame when anything
        goes wrong (the error is logged).
    """
    try:
        captions_df = pd.read_csv(BASE_CAPTIONS_FILE)
        row_count = len(captions_df)
        logging.info(f"load {row_count} query base captions")
    except Exception as err:
        logging.error(f"Error when loading base captions: {err}")
        return pd.DataFrame()
    return captions_df


def load_top10_articles():
    """Read TOP10_ARTICLES_FILE into a DataFrame.

    Returns:
        pandas.DataFrame: the parsed CSV, or an empty DataFrame when anything
        goes wrong (the error is logged).
    """
    try:
        articles_df = pd.read_csv(TOP10_ARTICLES_FILE)
        row_count = len(articles_df)
        logging.info(f"load {row_count} query articles")
    except Exception as err:
        logging.error(f"Error when loading top10 articles: {err}")
        return pd.DataFrame()
    return articles_df

def load_database():
    """Parse JSON_DB_FILE and return its content.

    Returns:
        The decoded JSON document, or an empty dict when the file cannot be
        read or parsed (the error is logged).
    """
    try:
        with open(JSON_DB_FILE, 'r', encoding='utf-8') as db_handle:
            articles = json.load(db_handle)
        article_count = len(articles)
        logging.info(f"loading database with {article_count} articles")
    except Exception as err:
        logging.error(f"Error when loading database: {err}")
        return {}
    return articles

In [None]:
import gc
def load_trained_model():
    """Load the fine-tuned model and tokenizer with Unsloth and prepare them for inference.

    The model is read from MODEL_PATH, falling back to "./lora_model" when that
    path does not exist. After loading, the model is switched to inference
    mode, a chat template is attached to the tokenizer ("qwen3", then
    "qwen-2.5" as a fallback), and a short warm-up generation is run.

    Returns:
        tuple: ``(model, tokenizer)`` on success, ``(None, None)`` when no model
        directory is found or any step raises (the error and traceback are
        printed).
    """
    # Release cached GPU memory left over from earlier cells before loading.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
    gc.collect()
    try:
        from unsloth import FastLanguageModel

        model_path = MODEL_PATH
        if not os.path.exists(model_path):
            print("There is no path to the newest model")
            model_path = "./lora_model"  # Fallback to LoRA model

        if not os.path.exists(model_path):
            print(f"Cannot find model at {model_path}")
            return None, None

        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name=model_path,
            max_seq_length=12000,
            load_in_4bit=True,
            dtype=None,
            device_map=0,  # GPU 0
        )

        FastLanguageModel.for_inference(model)

        from unsloth.chat_templates import get_chat_template
        try:
            tokenizer = get_chat_template(
                tokenizer,
                chat_template="qwen3",
            )
            print("Use chat template: qwen3")
        except Exception as e:
            print(f"Error when setting up qwen3 template: {e}")
            try:
                tokenizer = get_chat_template(
                    tokenizer,
                    chat_template="qwen-2.5",
                )
            except Exception as fallback_err:
                # Catch Exception (not a bare except) so Ctrl-C still interrupts,
                # and report why the fallback failed. The tokenizer is left
                # unchanged in this case.
                print(f"Cannot setup chat template: {fallback_err}")

        # Warm up model
        print("warm-up model...")
        test_input = "Test"
        inputs = tokenizer(test_input, return_tensors="pt")

        # Kaggle T4 only exposes cuda:0, matching device_map=0 above.
        inputs = {k: v.to("cuda:0") for k, v in inputs.items()}

        _ = model.generate(
            input_ids=inputs['input_ids'],
            max_new_tokens=10,
            use_cache=True,
            temperature=1.0,
        )
        del inputs, _
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        print("Warm-up complete!")

        print("Base model + LoRA adapter")
        return model, tokenizer

    except Exception as e:
        print(f"Error when loading model: {e}")
        import traceback
        traceback.print_exc()
        return None, None

In [None]:
# Load the fine-tuned model and tokenizer once for the rest of the notebook.
# load_trained_model() returns (None, None) on failure, so later cells must
# check `model is None` before using either object.
model, tokenizer = load_trained_model()

In [None]:

# Caption-generation driver. For each row of the top-10 table this cell looks up
# the query's base caption and its first article, builds a prompt, generates a
# caption, and records it in `results` (checkpointed, keyed by query id) and in
# `final_results` (written to OUTPUT_DIR when the cell finishes or is interrupted).
import time
import traceback

print("=" * 60)

if model is None:
    print("Cannot load model")
else:

    base_captions = load_base_captions()
    top10_df = load_top10_articles()
    database = load_database()

    # Each loader returns an empty container on failure. Without all three
    # inputs the loop would either crash or generate captions from missing
    # data, so in that case no rows are processed at all.
    data_ready = not (base_captions.empty or top10_df.empty or not database)
    if not data_ready:
        print("Cannot load data")
    rows_to_process = top10_df if data_ready else top10_df.iloc[0:0]

    # Load checkpoint
    results = load_checkpoint()

    show_resume_info(results, len(top10_df))

    cleanup_old_backups()

    final_results = []
    # Captions generated in THIS run only; used for timing so that entries
    # restored from the checkpoint do not distort the average and the ETA.
    generated_this_session = 0
    cuda_available = torch.cuda.is_available()

    try:
        start_time = time.time()

        for idx, row in tqdm(rows_to_process.iterrows(), total=len(rows_to_process),
                             desc="Generating captions", unit="query", dynamic_ncols=True):
            query_id = row['query_id']
            # numpy scalars cannot be written by json.dump; use the plain Python value.
            if isinstance(query_id, np.generic):
                query_id = query_id.item()
            # JSON object keys come back as strings after a checkpoint round-trip,
            # so the checkpoint is always keyed by the string form of the id.
            checkpoint_key = str(query_id)

            if checkpoint_key in results:
                print(f"Skip query {query_id}")
                final_results.append(results[checkpoint_key])
                continue

            # A query without a matching row, or with an empty/NaN caption,
            # is skipped instead of aborting the whole run.
            caption_matches = base_captions.loc[base_captions['query_id'] == query_id, 'caption']
            base_caption = caption_matches.iloc[0] if len(caption_matches) > 0 else None
            if base_caption is None or pd.isna(base_caption) or not base_caption:
                print(f"Cannot find base caption for query {query_id}")
                continue

            article_id_1 = row.get('article_id_1', '')
            # NaN is truthy, so missing CSV cells need an explicit check.
            if pd.isna(article_id_1) or not article_id_1:
                print(f"Cannot find article_id_1 cho query {query_id}")
                continue

            # The database comes from JSON, so its keys are strings; fall back to
            # the string form when the CSV value was parsed as a non-string.
            article_data = database.get(article_id_1)
            if article_data is None:
                article_data = database.get(str(article_id_1), {})

            # extract_key_info / create_advanced_prompt are defined earlier in the notebook.
            key_info = extract_key_info(article_data)

            prompt = create_advanced_prompt(base_caption, key_info)

            gpu_memory_before = torch.cuda.memory_allocated() / 1024**3 if cuda_available else 0.0

            print(f"Processing query {query_id}...")
            query_start_time = time.time()

            try:
                generated_caption = generate_caption(model, tokenizer, prompt)
                gen_time = time.time() - query_start_time

                # Only report per-query generation time for the first few queries.
                if generated_this_session <= 5:
                    print(f"Generation time: {gen_time:.1f}s")

            except Exception as e:
                print(f"Lỗi khi generate caption cho query {query_id}: {e}")
                generated_caption = ""

            query_duration = time.time() - query_start_time

            if cuda_available:
                gpu_memory_after = torch.cuda.memory_allocated() / 1024**3
                memory_change = gpu_memory_after - gpu_memory_before
                # The "+" format flag prints the correct sign for decreases too.
                print(f"Query time: {query_duration:.1f}s, GPU memory: {gpu_memory_after:.1f}GB ({memory_change:+.2f}GB)")

            if generated_caption and len(generated_caption.strip()) > 10:

                result_row = {
                    'query_id': query_id,
                    'generated_caption': generated_caption
                }

                final_results.append(result_row)
                results[checkpoint_key] = result_row
                generated_this_session += 1

                elapsed = time.time() - start_time
                avg_time = elapsed / generated_this_session

                print(f"Complete query {query_id} ({avg_time:.1f}s/query)")

                completed, remaining, progress = get_progress_info(results, len(top10_df))
                eta_seconds = remaining * avg_time if avg_time > 0 else 0
                eta_minutes = eta_seconds / 60

                print(f"📊 Progress: {completed}/{len(top10_df)} ({progress:.1f}%) - ETA: {eta_minutes:.1f} minutes")

                # Persist progress every 5 stored results.
                if len(results) % 5 == 0:
                    save_checkpoint(results)

                    if cuda_available:
                        current_memory = torch.cuda.memory_allocated() / 1024**3
                        print(f"GPU memory: {current_memory:.1f}GB")

                # Extra timestamped snapshot every 50 stored results.
                if len(results) % 50 == 0:
                    backup_file = create_backup_filename()
                    try:
                        with open(backup_file, 'w', encoding='utf-8') as f:
                            json.dump(results, f, ensure_ascii=False, indent=2)
                        print(f"Milestone backup at {len(results)} queries: {backup_file}")
                    except Exception as e:
                        print(f"Cannot create milestone backup: {e}")

            else:
                print(f"Cannot generate caption for query {query_id}")

    except KeyboardInterrupt:
        print("Paused by user. Saving checkpoint...")
        save_checkpoint(results, is_final=True)
        completed, remaining, progress = get_progress_info(results, len(top10_df))
        print(f"Stop at: {completed}/{len(top10_df)} queries ({progress:.1f}%)")

    except Exception as e:
        print(f"Error while processing: {e}")
        save_checkpoint(results, is_final=True)
        traceback.print_exc()

    finally:
        # Save the checkpoint first so that a failure while writing the output
        # files cannot lose results produced since the last periodic save.
        save_checkpoint(results, is_final=True)

        if final_results:

            os.makedirs(OUTPUT_DIR, exist_ok=True)

            with open(FINAL_OUTPUT_JSON, 'w', encoding='utf-8') as f:
                json.dump(final_results, f, ensure_ascii=False, indent=2)

            final_df = pd.DataFrame(final_results)
            final_df.to_csv(FINAL_OUTPUT_CSV, index=False, encoding='utf-8')

            print(f"Complete! Processed {len(final_results)} queries")
            print("Results are saved at:")
            print(f"   - JSON: {FINAL_OUTPUT_JSON}")
            print(f"   - CSV: {FINAL_OUTPUT_CSV}")

        if results:
            final_backup = os.path.join(CHECKPOINT_DIR, "final_backup.json")
            try:
                with open(final_backup, 'w', encoding='utf-8') as f:
                    json.dump(results, f, ensure_ascii=False, indent=2)
                print(f"Final backup is saved at: {final_backup}")
            except Exception as e:
                print(f"Cannot create final backup: {e}")