### Installation

In [None]:
%%capture
# %%capture hides the (long) pip output produced by this cell.
import os
# Colab is detected by looking for any environment variable whose name contains "COLAB_".
if "COLAB_" not in "".join(os.environ.keys()):
    # Not on Colab: a plain install, so pip resolves Unsloth's dependencies itself.
    !pip install unsloth
else:
    # On Colab: pinned packages are installed with --no-deps so pip does not pull in
    # (or replace) their dependencies -- presumably to keep Colab's preinstalled stack intact.
    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl==0.15.2 triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf "datasets>=3.4.1" huggingface_hub hf_transfer
    !pip install transformers==4.51.3
    !pip install --no-deps unsloth

In [None]:
# Extra packages: sentence-transformers backs the optional semantic search set up below;
# numpy is used for the similarity computations.
!pip install sentence-transformers numpy

In [None]:
# Mount Google Drive at /content/drive (requires the google.colab package, i.e. a Colab runtime).
from google.colab import drive
drive.mount('/content/drive')


In [None]:
# Model configuration.
# NOTE(review): the code that consumes this dict is not visible in this part of the
# notebook; the keys look like keyword arguments for an Unsloth model loader -- confirm.
MODEL_CONFIG = {
    "model_name": "unsloth/DeepSeek-R1-0528-Qwen3-8B",
    "max_seq_length": 12000,  # same value as TRAINING_CONFIG["max_seq_length"]
    "load_in_4bit": True,
    "load_in_8bit": False,
    "full_finetuning": False,
}

In [None]:
# LoRA configuration.
# NOTE(review): the consumer of this dict is not visible here; presumably these are the
# keyword arguments used when the LoRA adapters are attached to the model -- confirm.
LORA_CONFIG = {
    "r": 256,
    "target_modules": ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj"],
    "lora_alpha": 256,  # equal to "r" above
    "lora_dropout": 0,
    "bias": "none",
    "use_gradient_checkpointing": "unsloth",
    "random_state": 3407,  # same value as TRAINING_CONFIG["seed"]
    "use_rslora": False,
    "loftq_config": None,
}


In [None]:
# Training configuration.
# NOTE(review): the trainer that receives these values is not visible in this part of
# the notebook; the keys look like SFT trainer arguments -- confirm against the caller.
TRAINING_CONFIG = {
    "output_dir": "./sft_results_upgraded_prompt",
    "dataset_text_field": "text",
    "max_seq_length": 12000,  # same value as MODEL_CONFIG["max_seq_length"]
    "per_device_train_batch_size": 8,
    "gradient_accumulation_steps": 8,
    "warmup_steps": 10,
    "num_train_epochs": 2,
    "max_steps": -1,
    "learning_rate": 5e-5,
    "logging_steps": 1,
    "optim": "paged_adamw_8bit",
    "weight_decay": 0.01,
    "lr_scheduler_type": "cosine",
    "seed": 3407,  # same value as LORA_CONFIG["random_state"]
    "report_to": "none",
    "dataloader_num_workers": 3,
    "dataset_num_proc": 1,
}

In [None]:
# Content-processing thresholds.
# NOTE(review): the helper functions visible below do not read these three names; they
# use literal 3000 / 5000 / 6000 values instead -- confirm whether these are still used.
CONTENT_LENGTH_THRESHOLD_FOR_SUMMARY = 1000
MAX_SUMMARY_LEN_IN_PROMPT = 6000
MAX_LEN_SEMANTIC_EXTRACTION = 6000


# Same value as TRAINING_CONFIG["dataset_num_proc"].
NUM_PROC = 1

# Root of the Google Drive mount created by drive.mount('/content/drive') above.
SAVE_DIR = "/content/drive/MyDrive/"

# Relative path of the dataset JSON file.
# NOTE(review): presumably the file fetched by the gdown cell -- confirm.
DATASET_PATH = "output_merged.json"


In [None]:
import os
import json
import re
import numpy as np
from tqdm import tqdm
# Optional semantic search: when the sentence-transformers model can be loaded it is
# used to pick the most relevant article passages; otherwise the flag is switched off
# and the length-based fallbacks are used instead.
USE_SEMANTIC_SEARCH = True
MODEL_NAME_SEMANTIC = 'all-MiniLM-L12-v2'
sentence_model_global = None

if USE_SEMANTIC_SEARCH:
    try:
        from sentence_transformers import SentenceTransformer
        sentence_model_global = SentenceTransformer(MODEL_NAME_SEMANTIC)
    except ImportError:
        # Report the reason instead of silently degrading: the fallback changes which
        # article text ends up in the prompts.
        USE_SEMANTIC_SEARCH = False
        print("sentence-transformers is not installed; semantic search disabled.")
    except Exception as e:
        USE_SEMANTIC_SEARCH = False
        print(f"Could not load semantic model '{MODEL_NAME_SEMANTIC}' ({e}); semantic search disabled.")


In [None]:
# Download the training dataset from Google Drive by file id using gdown.
!gdown 1F0EMvGBm-l4iXV11Im3zaGOPbifsWxfQ #Link of training dataset for caption generation stage

In [None]:

def _identify_topic(text: str, titles: list) -> str:
    all_text = ' '.join(titles).lower() + ' ' + text.lower()
    topic_keywords = {
        'technology': ['technology', 'tech', 'software', 'AI', 'robot', 'digital', 'computer', 'innovation', 'platform', 'data'],
        'business': ['business', 'company', 'market', 'economy', 'trade', 'finance', 'investment', 'ceo', 'gdp', 'stock', 'enterprise'],
        'politics': ['election', 'president', 'government', 'policy', 'political', 'minister', 'congress', 'parliament', 'senate', 'legislation', 'bill', 'campaign'],
        'sports': ['game', 'player', 'team', 'match', 'championship', 'athlete', 'sport', 'win', 'tournament', 'olympics', 'score'],
        'health': ['health', 'medical', 'doctor', 'patient', 'disease', 'treatment', 'hospital', 'vaccine', 'pandemic', 'healthcare', 'medicine'],
        'environment': ['climate', 'environment', 'pollution', 'renewable', 'energy', 'sustainable', 'carbon', 'green', 'emissions', 'ecology'],
        'entertainment': ['movie', 'film', 'actor', 'music', 'artist', 'show', 'entertainment', 'celebrity', 'concert', 'awards'],
        'science': ['research', 'study', 'scientist', 'discovery', 'experiment', 'science', 'data', 'analysis', 'journal', 'university'],
        'education': ['education', 'school', 'university', 'college', 'student', 'teacher', 'learning', 'curriculum'],
        'social issues': ['social', 'community', 'human rights', 'inequality', 'poverty', 'justice', 'protest'],
        'world affairs': ['international', 'global', 'world', 'geopolitics', 'diplomacy', 'conflict', 'united nations'],
        'food': ['food', 'restaurant', 'chef', 'meal', 'cuisine', 'cooking', 'recipe', 'dining', 'agriculture'],
        'travel': ['travel', 'tourism', 'destination', 'flight', 'hotel', 'vacation', 'journey', 'trip', 'airport']
    }
    topic_scores = {}
    for topic, keywords in topic_keywords.items():
        score = sum(1 for keyword in keywords if keyword in all_text)
        if score > 0:
            topic_scores[topic] = score
    if topic_scores:
        return max(topic_scores, key=topic_scores.get)
    return 'general'

def _extract_organizations(text: str) -> list:
    organizations = []

    patterns = [
        r'\b[A-Z]{2,6}\b',
        r'\b[A-Z][a-zA-Z]*(?:\s+(?:and|of|the|for)\s+)?[A-Z][a-zA-Z]*(?:\s+(?:Inc|Corp|Ltd|LLC|Co|Group|Holdings|Foundation|Association|Organization|Agency|Department|University|Institute|College|School|Council|Committee|Party|Union|Bank|Studio|Network))\b',
        r'\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*\s+(?:Inc\.|Corp\.|Ltd\.|LLC|Co\.|Group|Foundation|Association|Organization|Agency|Department|University|Institute|College|School|Council|Committee|Party|Union|Bank|Studio|Network)\b',
    ]

    known_orgs = [
        'Google', 'Microsoft', 'Apple', 'Amazon', 'Meta', 'Facebook', 'Twitter', 'Netflix', 'Tesla',
        'United Nations', 'World Health Organization', 'European Union', 'NATO', 'NASA',
        'CNN', 'BBC', 'Reuters', 'Associated Press', 'New York Times', 'The Guardian',
    ]

    for pattern in patterns:
        matches = re.findall(pattern, text)
        organizations.extend(matches)

    for org in known_orgs:
        if re.search(r'\b' + re.escape(org) + r'\b', text, re.IGNORECASE):
            organizations.append(org)


    processed_orgs = []
    common_words_in_org_names = {'The', 'A', 'An', 'Of', 'And', 'For'}
    for org in organizations:
        org_stripped = org.strip()

        if len(org_stripped) <= 1 and org_stripped.isupper():
            continue
        if org_stripped.isupper() and len(org_stripped) > 6:
             if org_stripped not in known_orgs:
                continue

        if org_stripped in common_words_in_org_names:
            continue
        processed_orgs.append(org_stripped)


    final_orgs = list(set(processed_orgs))
    final_orgs.sort(key=lambda x: (len(x.split()), x.isupper()), reverse=True)
    return final_orgs[:8]


def _extract_people(text: str) -> list:
    people = []

    patterns = [
        r'\b(?:Mr\.|Mrs\.|Ms\.|Miss|Dr\.|Prof\.|President|CEO|Minister|Director|Ambassador|General|Captain|Chef|Senator|Governor|Mayor|Councillor|Judge|Justice)\s+([A-Z][a-z]+(?:\s+[A-Z][a-z]+){0,2})\b',

        r'\b([A-Z][a-z]+(?:\s+[A-Z][a-z\'\-]+){1,3})\b'
    ]

    for pattern in patterns:
        matches = re.findall(pattern, text)

        if isinstance(matches, list) and matches and isinstance(matches[0], tuple):
             people.extend([m[0] for m in matches if m[0]])
        else:
             people.extend(matches)

    non_name_keywords = {
        'January', 'February', 'March', 'April', 'May', 'June', 'July', 'August', 'September', 'October', 'November', 'December',
        'Monday', 'Tuesday', 'Wednesday', 'Thursday', 'Friday', 'Saturday', 'Sunday',
        'Today', 'Yesterday', 'Tomorrow', 'Week', 'Month', 'Year',
        'Street', 'Road', 'Avenue', 'City', 'State', 'Country', 'County', 'Park', 'Building', 'Center', 'Plaza', 'Square',
        'The', 'This', 'That', 'These', 'Those', 'And', 'But', 'For', 'With', 'From', 'About', 'Under', 'Over',
        'Is', 'Are', 'Was', 'Were', 'Has', 'Have', 'Had', 'Says', 'Said', 'Told',
        'North', 'South', 'East', 'West',
        'Company', 'Corporation', 'Incorporated', 'Limited', 'Organization', 'Department', 'University', 'Institute', 'College', 'School' # Tên tổ chức
    }

    processed_people = []
    for p_match in people:
        p = p_match.strip()
        words = p.split()

        if len(words) >= 2 and all(word[0].isupper() for word in words) and not all(word.isupper() for word in words) and not any(word in non_name_keywords for word in words) and len(p)>3 :
            processed_people.append(p)
        elif len(words) == 1 and p[0].isupper() and p not in non_name_keywords and len(p)>3 and not p.isupper(): # Tên một từ

            processed_people.append(p)



    final_people = list(set(processed_people))
    final_people.sort(key=len, reverse=True)
    return final_people[:8]

def _extract_locations(text: str) -> list:
    """Collect location-like names from ``text``.

    Candidates are (a) entries of a predefined country/city list found in the text
    (case-insensitive) and (b) the leading capitalised word of each capitalised
    phrase. Anything that ``_extract_people`` or ``_extract_organizations`` returns
    for the same text, plus month/day names and a few honorifics, is excluded.

    Args:
        text: Text to scan.

    Returns:
        Up to 8 unique location strings, longest first. Equal-length entries keep
        first-seen order, so the result is reproducible across interpreter runs.
    """
    # A missing comma after 'Washington D.C.' used to fuse it with 'San Francisco'
    # into one string, so neither city could ever be found.
    predefined_locations = [
        'Vietnam', 'United States', 'China', 'India', 'Japan', 'Germany', 'United Kingdom', 'France', 'Canada', 'Australia', 'Russia', 'Brazil', 'South Korea', 'Italy', 'Spain',
        'New York', 'Los Angeles', 'Chicago', 'London', 'Paris', 'Berlin', 'Tokyo', 'Beijing', 'Shanghai', 'Seoul', 'Moscow', 'Singapore', 'Sydney', 'Toronto', 'Rome', 'Madrid', 'Washington D.C.',
        'San Francisco', 'Silicon Valley'
    ]
    predefined_set = set(predefined_locations)

    locations = []
    for loc in predefined_locations:
        # Lookarounds instead of \b: they behave the same for names that start and
        # end with a letter, but also work for 'Washington D.C.', where a trailing
        # \b could not match between '.' and a following space.
        if re.search(r'(?<!\w)' + re.escape(loc) + r'(?!\w)', text, re.IGNORECASE):
            locations.append(loc)

    patterns = [
        r'\b([A-Z][a-zA-Z\']+)(?:\s+(?:of|de|the|la)\s+)?(?:[A-Z][a-zA-Z\']+){0,3}(?:,\s*[A-Z][a-zA-Z\.\s]+)?\b'
    ]
    for pattern in patterns:
        # Only group 1 (the first capitalised word) is returned by findall.
        matches = re.findall(pattern, text)
        locations.extend(match.strip() for match in matches if len(match.strip()) > 2)  # filter short results

    non_location_keywords = _extract_people(text) + _extract_organizations(text) + [
        'January', 'February', 'March', 'April', 'May', 'June', 'July', 'August', 'September', 'October', 'November', 'December',
        'Monday', 'Tuesday', 'Wednesday', 'Thursday', 'Friday', 'Saturday', 'Sunday', 'Mr', 'Ms', 'Dr'
    ]
    non_location_keywords_lower = {k.lower() for k in non_location_keywords}

    unique_locations = []
    seen = set()
    for loc in locations:
        loc_stripped = loc.strip()
        # Predefined names are kept verbatim so 'Washington D.C.' keeps its final dot.
        if loc_stripped not in predefined_set:
            loc_stripped = loc_stripped.rstrip(',.')

        is_single_word = len(loc_stripped.split()) == 1
        if is_single_word and (loc_stripped.isdigit() or (loc_stripped.isupper() and loc_stripped not in ['US', 'UK', 'EU'])):
            continue
        if loc_stripped.lower() in non_location_keywords_lower or len(loc_stripped) <= 2:
            continue
        if loc_stripped.lower().startswith("the ") and len(loc_stripped.split()) < 3:
            continue
        # De-duplicate in first-seen order; going through set() made the order of
        # equal-length entries depend on string hash randomisation.
        if loc_stripped not in seen:
            seen.add(loc_stripped)
            unique_locations.append(loc_stripped)

    # Stable sort: equal-length entries keep first-seen order.
    unique_locations.sort(key=len, reverse=True)
    return unique_locations[:8]


def _extract_events(text: str) -> list:
    events = []
    patterns = [
        r'\b(?:the\s+)?([A-Z][a-zA-Z0-9\s\'\-]+(?:Conference|Summit|Forum|Meeting|Festival|Games|Olympics|Championship|Cup|Awards|Exhibition|Show|Ceremony|Campaign|Initiative|Project|Program|Operation|War|Battle|Treaty|Accord|Act|Bill|Law|Debate|Election|Crisis|Pandemic|Outbreak|Attack|Incident|Disaster))\b',
        r'\b([A-Z][a-zA-Z]+\s+(?:World Cup|Olympic Games|Grand Prix|Open|Summit|Conference|Festival))\b',
        r'\b(?:G7 Summit|G20 Summit|COP\d+\sConference)\b',
        r'\b\d{4}\s+(?:Summer|Winter)\s+Olympics\b',
    ]

    known_events = ['World War I', 'World War II', 'Vietnam War', 'Cold War', 'September 11 attacks', 'COVID-19 Pandemic']

    for pattern in patterns:
        matches = re.findall(pattern, text)

        events.extend([m if isinstance(m, str) else m[0] for m in matches])

    for event in known_events:
        if re.search(r'\b' + re.escape(event) + r'\b', text, re.IGNORECASE):
            events.append(event)

    processed_events = [event.strip().rstrip(',.') for event in events if len(event.strip()) > 4]
    final_events = list(set(processed_events))
    final_events.sort(key=len, reverse=True)
    return final_events[:5]

def _extract_dates(text: str, provided_date: str = None) -> list:
    dates = []
    if provided_date:
        dates.append(provided_date.strip())

    patterns = [

        r'\b(?:Jan(?:uary)?|Feb(?:ruary)?|Mar(?:ch)?|Apr(?:il)?|May|Jun(?:e)?|Jul(?:y)?|Aug(?:ust)?|Sep(?:tember)?|Oct(?:ober)?|Nov(?:ember)?|Dec(?:ember)?)\s+\d{1,2}(?:st|nd|rd|th)?,\s+\d{4}\b',
        r'\b\d{1,2}\s+(?:Jan(?:uary)?|Feb(?:ruary)?|Mar(?:ch)?|Apr(?:il)?|May|Jun(?:e)?|Jul(?:y)?|Aug(?:ust)?|Sep(?:tember)?|Oct(?:ober)?|Nov(?:ember)?|Dec(?:ember)?)\s+\d{4}\b',

        r'\b\d{1,2}[/-]\d{1,2}[/-]\d{2,4}\b',
        r'\b\d{4}[/-]\d{1,2}[/-]\d{1,2}\b',

        r'\b(?:Jan(?:uary)?|Feb(?:ruary)?|Mar(?:ch)?|Apr(?:il)?|May|Jun(?:e)?|Jul(?:y)?|Aug(?:ust)?|Sep(?:tember)?|Oct(?:ober)?|Nov(?:ember)?|Dec(?:ember)?)\s+\d{4}\b',

        r'\b(?:Monday|Tuesday|Wednesday|Thursday|Friday|Saturday|Sunday)(?:,\s*(?:Jan(?:uary)?|Feb(?:ruary)?|Mar(?:ch)?|Apr(?:il)?|May|Jun(?:e)?|Jul(?:y)?|Aug(?:ust)?|Sep(?:tember)?|Oct(?:ober)?|Nov(?:ember)?|Dec(?:ember)?)\s+\d{1,2}(?:st|nd|rd|th)?,\s+\d{4})?\b',

        r'\b(?:yesterday|today|tomorrow|last\s+week|next\s+week|last\s+month|next\s+month|this\s+year|last\s+year|next\s+year)\b',

        r'\b(?:in|during|on|by|since|until|from|the\s+year\s+of)\s+(\d{4})\b',
        r'\b(\d{4})\b'
    ]
    for pattern in patterns:
        matches = re.findall(pattern, text, re.IGNORECASE)
        if pattern.endswith(r'\b(\d{4})\b') or pattern.endswith(r'\s+(\d{4})\b') :
            dates.extend(m for m in matches if 1900 <= int(m) <= 2050)
        else:
            dates.extend(m.strip() for m in matches)



    final_dates = []
    current_year = 2025
    for d_match in dates:
        d = d_match.strip().rstrip(',.')
        if d.isdigit() and len(d) == 4:
            year = int(d)
            if 1900 <= year <= current_year + 5:
                final_dates.append(d)
        elif len(d) > 3:
            final_dates.append(d)

    final_dates = list(set(final_dates))

    final_dates.sort(key=lambda x: (len(x), x), reverse=True)
    return final_dates[:5]

def _extract_key_terms(text: str, title:str) -> list:
    combined_text = (title.lower() + " ") * 3 + text.lower()


    text_no_urls = re.sub(r'http\S+|www.\S+|\S+@\S+', '', combined_text)


    words = re.findall(r'\b[a-zA-Z0-9][a-zA-Z0-9\-\']*[a-zA-Z0-9]\b', text_no_urls)


    stop_words = {
        'the', 'a', 'an', 'is', 'are', 'was', 'were', 'be', 'being', 'been', 'this', 'that', 'these', 'those',
        'and', 'or', 'but', 'if', 'of', 'at', 'by', 'for', 'with', 'about', 'to', 'from', 'in', 'out', 'on',
        'it', 'its', 'he', 'she', 'they', 'them', 'his', 'her', 'their', 'you', 'your', 'we', 'our',
        'i', 'me', 'my', 'mine', 'us', 'ours', 'myself', 'yourself', 'himself', 'herself', 'itself', 'ourselves', 'yourselves', 'themselves',
        'what', 'which', 'who', 'whom', 'whose', 'why', 'how', 'when', 'where',
        'all', 'any', 'both', 'each', 'few', 'more', 'most', 'other', 'some', 'such', 'no', 'nor', 'not', 'only', 'own', 'same', 'so', 'than', 'too', 'very',
        'can', 'will', 'just', 'don', 'should', 'now', 'do', 'does', 'did', 'doing', 'said', 'says', 'also', 'get', 'go', 'make', 'know', 'see', 'use', 'find', 'tell', 'ask', 'work', 'seem', 'feel', 'try', 'leave', 'call',
        'one', 'two', 'three', 'january', 'february', 'march', 'april', 'may', 'june', 'july', 'august', 'september', 'october', 'november', 'december',
        'mr', 'mrs', 'ms', 'dr', 'prof', 'inc', 'ltd', 'corp',
        'news', 'report', 'story', 'article', 'image', 'photo', 'picture', 'video', 'caption', 'description',
        'people', 'person', 'man', 'woman', 'child', 'children', 'group', 'team',
        'world', 'country', 'city', 'government', 'company', 'organization', 'event', 'system', 'part', 'number', 'way', 'thing', 'day', 'year', 'time', 'today', 'content', 'information', 'context', 'detail', 'example'
    }


    proper_nouns_phrases = re.findall(r'\b[A-Z][a-zA-Z0-9\-\']*(?:\s+[A-Z][a-zA-Z0-9\-\']*){0,3}\b', text)
    filtered_proper_nouns = []
    for phrase in proper_nouns_phrases:
        p_words = phrase.split()
        if not all(word.lower() in stop_words for word in p_words) and \
           not (len(p_words) == 1 and p_words[0].lower() in stop_words) and \
           len(phrase.strip()) > 2 :
            filtered_proper_nouns.append(phrase.strip())


    filtered_words = [word for word in words if word not in stop_words and len(word) > 2 and not word.isdigit()]

    term_freq = {}
    for term in filtered_proper_nouns + title.lower().split():
        if term.lower() not in stop_words and len(term)>2:
            term_freq[term.lower()] = term_freq.get(term.lower(), 0) + 2
    for word in filtered_words:
        term_freq[word] = term_freq.get(word, 0) + 1


    sorted_terms = sorted(term_freq.items(), key=lambda x: (x[1], len(x[0].split()), len(x[0])), reverse=True)

    final_terms = []
    seen_lower = set()
    for term, freq in sorted_terms:
        if term not in seen_lower:
            original_case_term = term
            for pn in filtered_proper_nouns:
                if pn.lower() == term:
                    original_case_term = pn
                    break
            final_terms.append(original_case_term)
            seen_lower.add(term)
        if len(final_terms) >= 10:
            break

    return final_terms


def _extract_numbers(text: str) -> list:

    patterns = [

        r'(?:\$|€|£|¥|USD|EUR|GBP|JPY|VND)\s*\d+(?:[.,]\d{3})*(?:[.,]\d+)?(?:\s*(?:million|billion|trillion|thousand|K|M|B|T))?\b',
        r'\b\d+(?:[.,]\d{3})*(?:[.,]\d+)?\s*(?:dollars?|euros?|pounds?|yen|đồng|USD|EUR|GBP|JPY|VND)(?:\s*(?:million|billion|trillion|thousand|K|M|B|T))?\b',

        r'\b\d+(?:[.,]\d+)?\s*%(?:\s*points)?\b',
        r'\b\d+(?:[.,]\d+)?\s*(?:percent|per\s+cent|percentage\s+points?)\b',

        r'\b\d+(?:[.,]\d{3})*(?:[.,]\d+)?\s*(?:people|users|viewers|votes|cases|deaths|infections|jobs|companies|countries|cities|members|students|teachers|schools|hospitals|doctors|patients|items|products|services|cars|houses|buildings|acres|hectares|tons|kg|grams|liters|gallons|km|kilometers|meters|miles|feet|gb|mb|tb|hz|watts|volts|degrees|°C|°F|points|barrels|shares|pages|chapters|articles|sections|votes)\b',

        r'\b\d+(?:[.,]\d+)?\s*(?:to|-|–)\s*\d+(?:[.,]\d+)?\b',

        r'\b\d{1,3}(?:[.,]\d{3})*(?:[.,]\d+)?\b',
        r'\b(?:age|aged)\s+\d+\b',
        r'\b\d+\s*years?\s*old\b',
    ]
    numbers = []
    for pattern in patterns:
        matches = re.findall(pattern, text, re.IGNORECASE)
        numbers.extend(matches)



    processed_numbers = []
    for num_match in numbers:
        num_str = num_match.strip()
        try:
            if re.fullmatch(r'\d+', num_str) and int(num_str) < 10:
                if not any(unit in text[text.find(num_str):text.find(num_str)+len(num_str)+10].lower() for unit in ['million', 'billion', 'thousand', '%', 'percent', 'degree']):
                    continue
        except ValueError:
            pass
        processed_numbers.append(num_str)

    final_numbers = list(set(processed_numbers))
    final_numbers.sort(key=lambda x: any(kw in x.lower() for kw in ['million', 'billion', 'trillion', 'percent', '%', '$', '€', '£']), reverse=True)
    return final_numbers[:10]


def extract_key_info_from_json_item(item_content: str, item_title: str, item_date: str) -> dict:
    """Build the key-information dictionary for one dataset item.

    Args:
        item_content: Article body (may be empty or None).
        item_title: Article title (may be empty or None).
        item_date: Publication date string (may be empty or None).

    Returns:
        A dict with the keys ``titles``, ``sources``, ``topic``, ``organizations``,
        ``people``, ``locations``, ``events``, ``numbers``, ``dates``, ``key_terms``
        and ``context``. When both title and content are empty, every field is empty
        except ``dates``, which carries the stripped ``item_date`` if one was given.
    """
    if not item_content and not item_title:
        return {
            'titles': [], 'sources': [], 'topic': 'general', 'organizations': [],
            'people': [], 'locations': [], 'events': [], 'numbers': [],
            'dates': [item_date.strip()] if item_date else [], 'key_terms': [], 'context': ''
        }

    title_text = item_title or ""
    body_text = item_content or ""
    combined_text = title_text + ' ' + body_text

    titles = [item_title.strip()] if item_title else []
    dates = _extract_dates(combined_text, item_date)
    organizations = _extract_organizations(combined_text)

    # Individual words (longer than 3 characters) of the organisation names; a person
    # or location containing one of them is treated as part of an organisation.
    org_tokens = [token for token in " ".join(organizations).split() if len(token) > 3]

    def _belongs_to_org(candidate: str) -> bool:
        return candidate in organizations or any(token in candidate for token in org_tokens)

    people = [person for person in _extract_people(combined_text) if not _belongs_to_org(person)]
    locations = [
        place for place in _extract_locations(combined_text)
        if place not in people and not _belongs_to_org(place)
    ]

    return {
        'titles': titles,
        'sources': [],
        'topic': _identify_topic(combined_text, titles),
        'organizations': organizations,
        'people': people,
        'locations': locations,
        'events': _extract_events(combined_text),
        'numbers': _extract_numbers(combined_text),
        'dates': dates,
        'key_terms': _extract_key_terms(combined_text, title_text),
        'context': body_text
    }

In [None]:
def _semantic_article_extraction(full_content: str, base_caption: str, key_info: dict) -> str:
    """Select the article sentences most relevant to the caption and key entities.

    The article is split into sentences, grouped into overlapping 3-sentence
    chunks, and each chunk is scored by its best cosine similarity against a set of
    queries (the base caption plus title / key terms / people / organisations /
    locations from ``key_info``). Sentences of the top-scoring chunks are combined
    with the first three and last two sentences and returned in article order.

    Args:
        full_content: Full article text.
        base_caption: Visual caption used as the primary query.
        key_info: Dict as produced by ``extract_key_info_from_json_item``.

    Returns:
        The extracted text, at most 6000 characters. If semantic search is
        disabled, the article has fewer than 5 usable sentences, or any error
        occurs, the first 5000 characters of ``full_content`` are returned instead
        (followed by "..." when the article was cut).
    """
    def _fallback() -> str:
        # Plain prefix truncation, used whenever semantic selection is not possible.
        return full_content[:5000] + "..." if len(full_content) > 5000 else full_content

    if not USE_SEMANTIC_SEARCH or sentence_model_global is None:
        return _fallback()

    try:
        sentences = re.split(r'(?<=[.!?])\s+', full_content.strip())
        sentences = [s.strip() for s in sentences if len(s.strip()) > 15]

        if len(sentences) < 5:
            return _fallback()

        # Overlapping windows: 3 sentences per chunk, advancing by 2.
        chunk_size = 3
        overlap = 1
        chunks = []
        for i in range(0, len(sentences) - chunk_size + 1, chunk_size - overlap):
            chunk = ' '.join(sentences[i : i + chunk_size])
            if chunk.strip():
                chunks.append(chunk)

        if not chunks:
            return _fallback()

        search_queries = [base_caption]
        if key_info.get('titles'): search_queries.append(f"Title context: {key_info['titles'][0]}")
        if key_info.get('key_terms'): search_queries.append(f"Key terms: {', '.join(key_info['key_terms'][:3])}")
        if key_info.get('people'): search_queries.append(f"People involved: {', '.join(key_info['people'][:2])}")
        if key_info.get('organizations'): search_queries.append(f"Organizations: {', '.join(key_info['organizations'][:2])}")
        if key_info.get('locations'): search_queries.append(f"Locations: {', '.join(key_info['locations'][:2])}")

        chunk_embeddings = sentence_model_global.encode(chunks, show_progress_bar=False, batch_size=128)
        query_embeddings = sentence_model_global.encode(search_queries, show_progress_bar=False, batch_size=128)

        # Cosine similarity of every chunk against every query. The denominator is
        # floored so an all-zero embedding scores 0 instead of producing NaN/inf.
        chunk_norms = np.linalg.norm(chunk_embeddings, axis=1)
        all_sim_scores = []
        for query_emb in query_embeddings:
            denominator = np.maximum(chunk_norms * np.linalg.norm(query_emb), 1e-12)
            all_sim_scores.append(np.dot(chunk_embeddings, query_emb) / denominator)

        # A chunk is as relevant as its best-matching query.
        combined_scores = np.max(np.array(all_sim_scores), axis=0)

        # Keep roughly 30% of the chunks, but between 5 and 10 of them.
        num_top_chunks = min(max(5, int(len(chunks) * 0.3)), 10)
        top_indices_by_score = np.argsort(combined_scores)[-num_top_chunks:]

        # Restore article order for the selected chunks.
        relevant_text_parts = [chunks[idx] for idx in sorted(top_indices_by_score)]

        # Chunks overlap, so split them back into sentences and de-duplicate.
        final_relevant_sentences = set()
        for part in relevant_text_parts:
            for s_in_chunk in re.split(r'(?<=[.!?])\s+', part.strip()):
                if s_in_chunk.strip():
                    final_relevant_sentences.add(s_in_chunk.strip())

        # Always keep the opening and closing sentences of the article.
        first_sentences = [s.strip() for s in sentences[:min(3, len(sentences))] if s.strip()]
        last_sentences = [s.strip() for s in sentences[max(0, len(sentences)-2):] if s.strip()]

        combined_set = set(first_sentences) | final_relevant_sentences | set(last_sentences)

        # Order by position in the article (a repeated sentence maps to its last position).
        sentence_to_original_index = {sentence: i for i, sentence in enumerate(sentences)}
        sorted_combined_sentences = sorted(list(combined_set), key=lambda s: sentence_to_original_index.get(s, float('inf')))

        extracted_text = ' '.join(sorted_combined_sentences)

        max_len = 6000
        if len(extracted_text) > max_len:
            extracted_text = extracted_text[:max_len]
            # Prefer cutting at the last full stop inside the limit.
            last_sentence_end = extracted_text.rfind('.')
            if last_sentence_end > 0:
                extracted_text = extracted_text[:last_sentence_end+1]
            else:
                extracted_text += "..."
        return extracted_text

    except Exception as e:
        # Best effort: report the problem and fall back to plain truncation.
        print(f"Error in semantic search: {e}")
        import traceback
        traceback.print_exc()
        return _fallback()

In [None]:
def _sample_article_by_length(full_context: str) -> str:
    """Shorten an article without semantic search, based only on its length."""
    article_length = len(full_context)
    if article_length <= 3000:
        # Short article: use full content
        return full_context

    if article_length <= 8000:
        # Medium article: smart sampling with higher density
        sentences = full_context.split('. ')
        total_sentences = len(sentences)

        # First 10 sentences (usually most important)
        key_sentences = list(sentences[:10])
        # Every 3rd sentence from the middle section
        key_sentences.extend(sentences[i] for i in range(8, total_sentences - 4, 3))
        # Last 4 sentences (conclusions, outcomes)
        key_sentences.extend(sentences[-4:])

        # Remove duplicates (and blank pieces) while preserving order
        seen = set()
        unique_sentences = []
        for sentence in key_sentences:
            if sentence.strip() and sentence not in seen:
                seen.add(sentence)
                unique_sentences.append(sentence)
        return '. '.join(unique_sentences)

    # Long article: first 5000 characters plus every 5th sentence of the rest (at most 20).
    article_summary = full_context[:5000]
    additional_sentences = full_context[5000:].split('. ')[::5][:20]
    if additional_sentences:
        article_summary += ". " + '. '.join(additional_sentences)
    return article_summary


def create_advanced_prompt(base_caption: str, key_info: dict) -> str:
    """Build the caption-generation prompt for one image/article pair.

    Args:
        base_caption: Short visual description of the image.
        key_info: Dict as produced by ``extract_key_info_from_json_item``; missing,
            ``None`` or empty entries are tolerated.

    Returns:
        The full prompt text: the visual caption, a context block (title, up to
        three organisations/people/locations/figures/dates, up to two events and an
        article summary capped at 6000 characters) and the fixed instructions.
    """
    topic = key_info.get('topic') or 'general'
    source_str = ""

    organizations = (key_info.get('organizations') or [])[:3]
    people = (key_info.get('people') or [])[:3]
    locations = (key_info.get('locations') or [])[:3]
    events = (key_info.get('events') or [])[:2]
    numbers = (key_info.get('numbers') or [])[:3]
    dates = (key_info.get('dates') or [])[:3]

    full_context = key_info.get('context') or ''
    # 'titles' is an empty list for items without a title; indexing [0] directly on
    # the .get() result raised IndexError in that case.
    titles = key_info.get('titles') or ['']
    article_title = titles[0]

    context_elements = []
    if article_title:
        context_elements.append(f"MAIN STORY: {article_title}")
    if organizations:
        context_elements.append(f"KEY ORGANIZATIONS: {', '.join(organizations)}")
    if people:
        context_elements.append(f"PEOPLE INVOLVED: {', '.join(people)}")
    if locations:
        context_elements.append(f"LOCATIONS: {', '.join(locations)}")
    if events:
        context_elements.append(f"EVENTS: {', '.join(events)}")
    if numbers:
        context_elements.append(f"KEY FIGURES: {', '.join(numbers)}")
    if dates:
        context_elements.append(f"TIMELINE: {', '.join(dates)}")

    if full_context:
        # Semantic selection is only used for longer articles and only when the
        # embedding model is available; otherwise fall back to length-based sampling.
        if USE_SEMANTIC_SEARCH and sentence_model_global and len(full_context) > 3000:
            article_summary = _semantic_article_extraction(full_context, base_caption, key_info)
        else:
            article_summary = _sample_article_by_length(full_context)

        max_length = 6000
        if len(article_summary) > max_length:
            truncated = article_summary[:max_length]
            last_period = truncated.rfind('. ')
            # Cut at a sentence end only if that keeps at least 85% of the budget.
            if last_period > max_length * 0.85:
                article_summary = truncated[:last_period + 1]
            else:
                article_summary = truncated + "..."

        context_elements.append(f"COMPREHENSIVE ARTICLE CONTENT: {article_summary}")

    context_str = '\n'.join(context_elements) if context_elements else "General news context available."


    prompt = f"""You are a news caption expert. Your task is to write a news caption that PRIORITIZES the article content and news significance over visual description.

BRIEF VISUAL: {base_caption}

PRIORITY NEWS CONTEXT{source_str}:
Topic: {topic.title()}
{context_str}

CRITICAL INSTRUCTIONS:
1. The NEWS CONTEXT is MORE IMPORTANT than visual details
2. Start with "The image shows" but immediately connect to the news story
3. Use 70% article information + 30% visual description (ensure the connectivity between each visual element and the article content)
4. Focus on WHO, WHAT, WHY, WHEN, WHERE from the article
5. Mention specific names, organizations, events from the article
6. Explain the news significance and broader implications
7. Only describe visual elements that support the news story
8. Write 300-350 words prioritizing factual news content

GOOD EXAMPLE (prioritizing news over visuals):
"The image shows the scene from a significant political development as President Biden announces new healthcare legislation during a White House ceremony. This landmark bill, supported by Democratic leadership including Speaker Pelosi, aims to expand Medicare coverage to millions of Americans. The legislation comes after months of negotiations with pharmaceutical companies and represents a major victory for the administration's domestic agenda. The outdoor ceremony, attended by healthcare advocates and congressional leaders, marks the culmination of a campaign promise made during the 2020 election. The new law is expected to reduce prescription drug costs by 15% and provide coverage for dental and vision services, affecting approximately 12 million seniors nationwide."

YOUR CAPTION (prioritize article news content over visual description):"""
    return prompt

In [None]:

import os
import json
from tqdm import tqdm
from datasets import Dataset
import pandas as pd

def create_sample_dataset():
    """Write two hard-coded example news records to DATASET_PATH as JSON.

    Each record carries the keys image_id, base_caption, title, content,
    date and label. The file is written unconditionally; whether a sample
    file is needed is decided by the caller.
    """
    summit_record = {
        "image_id": "sample_img_001",
        "base_caption": "Several politicians are seated at a long table during a formal meeting in a well-lit conference room.",
        "title": "International Summit Addresses Global Economic Challenges",
        "content": "Leaders from twenty major economies convened today in Geneva for the annual Global Economic Forum (GEF). The summit, running from June 3rd to June 5th, 2025, aims to tackle pressing issues such as rising inflation, supply chain disruptions, and the future of digital currencies. Keynote speaker Dr. Aris Thorne, head of the World Monetary Institute (WMI), presented a stark outlook, urging coordinated international action. Discussions also involved representatives from major corporations like OmniCorp and TechSolutions Inc. A significant portion of the agenda is dedicated to sustainable development goals. The GEF's final communiqué is expected to outline a roadmap for global recovery. Last year's summit resulted in a $100 billion pledge for infrastructure projects.",
        "date": "2025-06-03",
        "label": "The image captures a moment from the Global Economic Forum in Geneva, where leaders from twenty nations gathered between June 3-5, 2025, to discuss global inflation and supply chain issues. Dr. Aris Thorne of the World Monetary Institute highlighted the need for coordinated action. The summit, also attended by corporations like OmniCorp, is focusing on sustainable development and digital currencies, with a final roadmap for economic recovery anticipated."
    }
    lab_record = {
        "image_id": "sample_img_002",
        "base_caption": "A scientist in a lab coat is looking intently at a test tube with blue liquid.",
        "title": "Breakthrough in Cancer Research Announced by MedSynth Labs",
        "content": "MedSynth Labs, a leading biomedical research institute based in Cambridge, today announced a significant breakthrough in the development of a novel targeted therapy for lung cancer. The research, published in the 'Journal of Oncology', details a new compound, LX-7, that has shown remarkable efficacy in preclinical trials, shrinking tumors by up to 80%. The research team, led by Dr. Lena Hanson, has been working on this project for over seven years. \"This could revolutionize how we treat certain aggressive forms of lung cancer,\" Dr. Hanson stated at a press conference. MedSynth Labs plans to begin Phase 1 human trials by early 2026. The study was partially funded by a $5 million grant from the National Health Foundation (NHF). This development offers new hope for patients worldwide.",
        "date": "2025-06-04",
        "label": "The image likely depicts a scientist at MedSynth Labs in Cambridge, a facility that recently announced a breakthrough in lung cancer research with a new compound, LX-7. Led by Dr. Lena Hanson and published in the 'Journal of Oncology' on June 4, 2025, the therapy showed an 80% tumor reduction in preclinical trials. MedSynth Labs, supported by a $5 million NHF grant, aims for human trials by early 2026, offering new hope for cancer treatment."
    }
    sample_records = [summit_record, lab_record]

    with open(DATASET_PATH, 'w', encoding='utf-8') as out_file:
        json.dump(sample_records, out_file, indent=4)

def load_dataset():
    """Return the parsed JSON content of DATASET_PATH.

    When the file is missing, the sample dataset is generated first so the
    subsequent read has something to load.

    Raises:
        json.JSONDecodeError: the file is not valid JSON (re-raised after logging).
        FileNotFoundError: the file could not be opened (re-raised after logging).
    """
    if os.path.exists(DATASET_PATH):
        print(f"Found file dataset at: {DATASET_PATH}")
    else:
        create_sample_dataset()

    try:
        with open(DATASET_PATH, 'r', encoding='utf-8') as json_file:
            return json.load(json_file)
    except FileNotFoundError:
        print(f"❌ File '{DATASET_PATH}' not exists.")
        raise
    except json.JSONDecodeError as decode_error:
        print(f"Error when decoding file JSON: {decode_error}")
        raise

def process_dataset_items(dataset_json):
    """Turn raw dataset records into user/assistant conversation pairs.

    For every record the user turn is the prompt built from the base caption
    and the key information extracted from the article; the assistant turn is
    the record's reference caption (``label``).

    Records are skipped when they are not JSON objects, when ``label`` or
    ``base_caption`` is empty, or when both ``title`` and ``content`` are
    empty. Records whose prompt construction raises are logged and skipped.

    Args:
        dataset_json: iterable of dict records with the optional keys
            image_id, base_caption, title, content, date and label.

    Returns:
        list of two-message conversations
        (``[{"role": "user", ...}, {"role": "assistant", ...}]``).

    Raises:
        ValueError: if no record produced a conversation.
    """
    import traceback

    def _text_field(record, key):
        # JSON null decodes to None and numbers decode to int/float; calling
        # .strip() on those directly would abort the whole run, so normalise
        # every field to a stripped string first.
        value = record.get(key)
        return "" if value is None else str(value).strip()

    all_conversations = []

    for i, item in enumerate(tqdm(dataset_json, desc="processing items", unit="item")):
        if not isinstance(item, dict):
            print(f"Skipping item {i}: expected a JSON object, got {type(item).__name__}")
            continue

        image_id = item.get('image_id', f'unknown_id_{i}')
        base_caption = _text_field(item, "base_caption")
        title = _text_field(item, "title")
        content = _text_field(item, "content")
        date = _text_field(item, "date")
        label = _text_field(item, "label")

        # A usable example needs a target caption, a visual caption and at
        # least some article text to build the prompt from.
        if not label or not base_caption or not (title or content):
            continue

        try:
            key_info = extract_key_info_from_json_item(content, title, date)
            user_prompt = create_advanced_prompt(base_caption, key_info)

            all_conversations.append([
                {"role": "user", "content": user_prompt},
                {"role": "assistant", "content": label},
            ])
        except Exception as e:
            # Best effort: one bad record must not stop the remaining ones.
            print(f"Error when processing item {image_id}: {e}")
            traceback.print_exc()
            print(f"Data of item is error: base_caption='{base_caption[:50]}...', title='{title[:50]}...'")
            continue

    if not all_conversations:
        raise ValueError("There is no conversation.")

    print(f"Successfully process {len(all_conversations)} conversation examples from dataset.")
    return all_conversations

def apply_chat_template(all_conversations, tokenizer):
    """Render every conversation to a plain string with the tokenizer's chat template.

    If the tokenizer has no chat template, a minimal fallback template is
    assigned to ``tokenizer.chat_template`` first (the tokenizer is mutated).
    A preview of the first rendered text (up to 1000 characters) is printed.

    Args:
        all_conversations: list of conversations, each a list of
            ``{"role", "content"}`` message dicts.
        tokenizer: object exposing ``chat_template`` and ``apply_chat_template``.

    Returns:
        The value returned by ``tokenizer.apply_chat_template`` (called with
        ``tokenize=False`` and ``add_generation_prompt=False``).

    Raises:
        ValueError: if the rendered result is empty.
    """
    if tokenizer.chat_template is None:
        print("⚠️ Tokenizer do not have available template chat")
        fallback_template_parts = [
            "{% if messages[0]['role'] == 'system' %}",
            "{% set loop_messages = messages[1:] %}",
            "{% set system_message = messages[0]['content'] %}",
            "{% else %}",
            "{% set loop_messages = messages %}",
            "{% set system_message = false %}",
            "{% endif %}",
            "{% for message in loop_messages %}",
            "{% if loop.index0 == 0 %}<|begin_of_sentence|>{% endif %}",
            "{% if message['role'] == 'user' %}",
            "<|User|>{{ message['content'] }}",
            "{% elif message['role'] == 'assistant' %}",
            "<|Assistant|>{{ message['content'] }}",
            "{% if loop.last %}<|end_of_sentence|>{% endif %}",
            "{% endif %}",
            "{% endfor %}",
        ]
        tokenizer.chat_template = "".join(fallback_template_parts)

    rendered_texts = tokenizer.apply_chat_template(
        all_conversations,
        tokenize=False,
        add_generation_prompt=False,
    )

    # Fail fast when nothing was rendered; otherwise show a short preview.
    if not rendered_texts:
        print("No data after applying template chat")
        raise ValueError("formatted_texts is empty. Check data processing and chat template application.")

    print(rendered_texts[0][:1000] + "...")
    return rendered_texts

def create_huggingface_dataset(formatted_texts):
    """Wrap the rendered texts in a shuffled Hugging Face ``Dataset``.

    The texts go into a single ``text`` column, the dataset is shuffled with
    seed 3407, a 500-character preview of the first row is printed, and the
    output format is set to ``torch``.

    Args:
        formatted_texts: sequence of training strings.

    Returns:
        The shuffled dataset.

    Raises:
        ValueError: if the dataset has no rows.
    """
    text_frame = pd.DataFrame({"text": formatted_texts})
    hf_dataset = Dataset.from_pandas(text_frame)

    if len(hf_dataset) == 0:
        print("Dataset is empty after processing.")
        raise ValueError("Cannot proceed with an empty dataset.")

    hf_dataset = hf_dataset.shuffle(seed=3407)
    print(hf_dataset[0]['text'][:500] + "...")

    hf_dataset.set_format(type="torch")
    return hf_dataset

def prepare_training_data(tokenizer):
    """Run the full data pipeline and return the dataset used for training.

    Stages, in order: load the raw JSON records, convert them to
    conversations, render them with the tokenizer's chat template, and wrap
    the rendered texts in a Hugging Face dataset.

    Args:
        tokenizer: tokenizer passed to the chat-template rendering stage.

    Returns:
        The dataset produced by ``create_huggingface_dataset``.
    """
    raw_records = load_dataset()
    conversations = process_dataset_items(raw_records)
    rendered_texts = apply_chat_template(conversations, tokenizer)
    return create_huggingface_dataset(rendered_texts)

In [None]:
import unsloth
from unsloth import FastLanguageModel
import torch


# Load the base model and its tokenizer; every loading option comes from MODEL_CONFIG.
model, tokenizer = FastLanguageModel.from_pretrained(
    **{
        option: MODEL_CONFIG[option]
        for option in (
            "model_name",
            "max_seq_length",
            "load_in_4bit",
            "load_in_8bit",
            "full_finetuning",
        )
    }
)

In [None]:
# Attach the LoRA adapters to the loaded model. Every hyperparameter is read
# from LORA_CONFIG so the configuration cell stays the single source of truth;
# hard-coding values here would make edits to LORA_CONFIG silently ineffective.
model = FastLanguageModel.get_peft_model(
    model,
    r=LORA_CONFIG["r"],
    target_modules=LORA_CONFIG["target_modules"],
    lora_alpha=LORA_CONFIG["lora_alpha"],
    lora_dropout=LORA_CONFIG["lora_dropout"],
    bias=LORA_CONFIG["bias"],
    use_gradient_checkpointing=LORA_CONFIG["use_gradient_checkpointing"],
    random_state=LORA_CONFIG["random_state"],
    use_rslora=LORA_CONFIG["use_rslora"],
    loftq_config=LORA_CONFIG["loftq_config"],
)

In [None]:
# Build the training dataset; uses the `tokenizer` returned by the model-loading cell.
combined_dataset = prepare_training_data(tokenizer)

In [None]:
import os
import torch
from trl import SFTConfig, SFTTrainer
def setup_training_args():
    """Build the ``SFTConfig`` for this run from TRAINING_CONFIG.

    When CUDA is available, mixed precision is selected from the GPU's
    compute capability (bf16 for major version >= 8, fp16 otherwise) and
    gradient checkpointing is switched on. These flags are passed to the
    ``SFTConfig`` constructor rather than assigned afterwards, so they go
    through the config's own initialisation and validation.

    Checkpoints are saved every 10 steps.

    Returns:
        The configured ``SFTConfig`` instance.
    """
    hardware_kwargs = {}
    if torch.cuda.is_available():
        if torch.cuda.get_device_capability()[0] >= 8:
            hardware_kwargs["bf16"] = True
            print("BF16 enabled for training.")
        else:  # Older GPUs
            hardware_kwargs["fp16"] = True
            print("FP16 enabled for training.")

        hardware_kwargs["gradient_checkpointing"] = True
        print("Gradient checkpointing enabled.")

    training_args_sft = SFTConfig(
        output_dir=TRAINING_CONFIG["output_dir"],
        dataset_text_field=TRAINING_CONFIG["dataset_text_field"],
        max_seq_length=TRAINING_CONFIG["max_seq_length"],
        per_device_train_batch_size=TRAINING_CONFIG["per_device_train_batch_size"],
        gradient_accumulation_steps=TRAINING_CONFIG["gradient_accumulation_steps"],
        warmup_steps=TRAINING_CONFIG["warmup_steps"],
        num_train_epochs=TRAINING_CONFIG["num_train_epochs"],
        max_steps=TRAINING_CONFIG["max_steps"],
        learning_rate=TRAINING_CONFIG["learning_rate"],
        logging_steps=TRAINING_CONFIG["logging_steps"],
        optim=TRAINING_CONFIG["optim"],
        weight_decay=TRAINING_CONFIG["weight_decay"],
        lr_scheduler_type=TRAINING_CONFIG["lr_scheduler_type"],
        seed=TRAINING_CONFIG["seed"],
        report_to=TRAINING_CONFIG["report_to"],
        dataloader_num_workers=TRAINING_CONFIG["dataloader_num_workers"],
        save_strategy="steps",
        save_steps=10,
        **hardware_kwargs,
    )

    return training_args_sft

def print_gpu_stats():
    """Print the GPU name, its total memory and the peak memory reserved so far.

    Returns:
        ``(reserved_gb, total_gb)`` for device 0, each rounded to three
        decimals, or ``(0, 0)`` when CUDA is not available.
    """
    if not torch.cuda.is_available():
        print("CUDA is not available")
        return 0, 0

    bytes_per_gb = 1024**3
    device_props = torch.cuda.get_device_properties(0)
    reserved_gb = round(torch.cuda.max_memory_reserved() / bytes_per_gb, 3)
    total_gb = round(device_props.total_memory / bytes_per_gb, 3)

    print(f"GPU = {device_props.name}. Max memory = {total_gb} GB.")
    print(f"{reserved_gb} GB of memory reserved before training.")
    return reserved_gb, total_gb

def print_training_stats(trainer_stats, start_gpu_memory, max_memory):
    """Print the training runtime and, on CUDA, peak reserved-memory figures.

    Args:
        trainer_stats: object whose ``metrics`` mapping holds ``train_runtime``
            in seconds.
        start_gpu_memory: GB reserved before training started.
        max_memory: total GPU memory in GB (used as the percentage base).
    """
    if not torch.cuda.is_available():
        # CPU training: only the runtime is meaningful.
        print(f"\n{trainer_stats.metrics['train_runtime']:.2f} seconds used for training (CPU).")
        return

    peak_gb = round(torch.cuda.max_memory_reserved() / 1024**3, 3)
    training_gb = round(peak_gb - start_gpu_memory, 3)
    peak_pct = round(peak_gb / max_memory * 100, 3)
    training_pct = round(training_gb / max_memory * 100, 3)

    runtime_seconds = trainer_stats.metrics['train_runtime']
    report_lines = (
        f"\n{runtime_seconds:.2f} seconds used for training.",
        f"{runtime_seconds/60:.2f} minutes used for training.",
        f"Peak reserved memory = {peak_gb} GB.",
        f"Peak reserved memory for training = {training_gb} GB.",
        f"Peak reserved memory % of max memory = {peak_pct} %.",
        f"Peak reserved memory for training % of max memory = {training_pct} %.",
    )
    for line in report_lines:
        print(line)

def save_model_and_tokenizer(trainer, tokenizer, training_args_sft):
    """Save the trained model under ``<output_dir>/final_checkpoint``.

    The tokenizer is saved to the same directory only when the trainer's
    model exposes a ``peft_config`` attribute.

    Args:
        trainer: object with ``save_model(path)`` and a ``model`` attribute.
        tokenizer: object with ``save_pretrained(path)``.
        training_args_sft: object whose ``output_dir`` is the parent directory.
    """
    checkpoint_dir = os.path.join(training_args_sft.output_dir, "final_checkpoint")
    trainer.save_model(checkpoint_dir)
    print(f"Model/Adapter is saved at: {checkpoint_dir}")

    if not hasattr(trainer.model, 'peft_config'):
        return

    tokenizer.save_pretrained(checkpoint_dir)
    print(f"Tokenizer is saved at: {checkpoint_dir}")

def train_model(model, tokenizer, combined_dataset):
    """Run supervised fine-tuning and save the resulting checkpoint.

    Training is skipped when CUDA is unavailable or the dataset is empty.
    Errors raised by ``trainer.train()`` (or by the post-training report and
    save) are logged with a traceback and reported as a ``None`` result
    rather than propagated.

    Args:
        model: the (PEFT-wrapped) model to train.
        tokenizer: tokenizer handed to the trainer and saved with the model.
        combined_dataset: dataset with the training text column.

    Returns:
        The ``SFTTrainer`` on success; ``None`` when training was skipped or
        failed.
    """
    if not torch.cuda.is_available() or len(combined_dataset) == 0:
        print("SFTTrainer is skipped due to the lack of CUDA or empty dataset.")
        return None

    # Only needed on the training path, so imported locally.
    import datasets
    import multiprocessing as mp
    import traceback

    # Silence the tokenizers fork warning / avoid its thread pool in workers.
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # Setup training arguments
    training_args_sft = setup_training_args()
    # Force single process for dataset processing
    training_args_sft.dataset_num_proc = 1

    # Print GPU stats
    start_gpu_memory, max_memory = print_gpu_stats()

    # Turn off the datasets on-disk cache for the trainer's preprocessing.
    datasets.disable_caching()

    # Set the datasets library's default batch size, when that setting exists.
    if hasattr(datasets.config, 'DEFAULT_MAX_BATCH_SIZE'):
        datasets.config.DEFAULT_MAX_BATCH_SIZE = 1000

    # Make multiprocessing report a single CPU while the trainer is built.
    # The patch is process-wide, so it is undone in `finally` even if the
    # trainer constructor raises.
    original_cpu_count = mp.cpu_count
    mp.cpu_count = lambda: 1
    try:
        # The DataLoader worker count is taken from training_args_sft; it is
        # not an SFTTrainer constructor argument.
        trainer = SFTTrainer(
            model=model,
            tokenizer=tokenizer,
            train_dataset=combined_dataset,
            args=training_args_sft,
            packing=False,
        )
    finally:
        mp.cpu_count = original_cpu_count

    print("\nStart training the model...")

    # The trainer may have preprocessed the dataset; re-check it is non-empty.
    if len(trainer.train_dataset) == 0:
        print("dataset is empty.")
        return None

    try:
        trainer_stats = trainer.train()

        # Print training statistics
        print_training_stats(trainer_stats, start_gpu_memory, max_memory)

        # Save model and tokenizer
        save_model_and_tokenizer(trainer, tokenizer, training_args_sft)

        return trainer

    except Exception as e:
        print(f"Error in training: {e}")
        traceback.print_exc()
        return None

def save_final_model(model, tokenizer, save_path="/content/drive/MyDrive/lora_model_512"):
    """Save the model and then the tokenizer to ``save_path``.

    Args:
        model: object with ``save_pretrained(path)``.
        tokenizer: object with ``save_pretrained(path)``.
        save_path: destination directory for both artifacts.
    """
    print(f"Saving the final model at: {save_path}")
    for artifact in (model, tokenizer):
        artifact.save_pretrained(save_path)
    print(f"Successfully save model at: {save_path}")



In [None]:
trainer = train_model(model, tokenizer, combined_dataset)
if trainer is None:
    # train_model returns None when training was skipped or failed. Saving in
    # that case would write an untrained adapter over the final checkpoint.
    print("Training did not complete; skipping the final model save.")
else:
    print("save the final model")
    save_final_model(model, tokenizer)