In [None]:
from data_gatherer.data_gatherer import DataGatherer
from data_gatherer.llm.response_schema import *
import sklearn
import numpy as np
import pandas as pd
import os, time, re
from lxml import etree
import json

In [None]:
# Ground-truth table; later cells filter it on its 'pmcid' column and read its 'identifier' column.
df_gt = pd.read_parquet("scripts/output/gold/dataset_citation_records_Table.parquet")

# Text file with one article entry per line (read into `pmcids` below).
input_file = "scripts/exp_input/REV.txt"
# JSONL file holding the batch requests (prompts); re-read later to inspect and pair prompts with results.
batch_file_path=f'scripts/tmp/train_data_openai_RTR-3_DataRef-REV.jsonl'
# JSONL file the downloaded batch results are written to.
output_batch_file = 'scripts/tmp/Train_results.jsonl'

# Workbook used for manual validation of the prompt/response pairs.
excel_output_path = 'scripts/tmp/training_dataset_validation.xlsx'

# Passed to DataGatherer as `llm_name`.
model_name = "gpt-4o-mini" 
# Passed to DataGatherer as `process_entire_document` and to process_articles as `full_document_read`.
FDR = False
# Retrieval options forwarded to parser.retrieve_relevant_content (see the batch preparation cell).
semantic_retrieval = True
brute_force_RegEx_ID_ptrs = True  # forwarded as include_snippets_with_ID_patterns

top_k = 3
# Third argument of dg.init_parser_by_input_type; None presumably selects the library default — TODO confirm.
embeddings_retriever_model = None
dedup = True  # forwarded as skip_rule_based_retrieved_elm
# Name of the prompt template loaded through parser.prompt_manager.load_prompt.
prompt_name = "GPT_FewShot"


In [None]:
# Read the list of articles to process from the input file (one entry per line).
# Entries are stripped and blank lines skipped so that stray whitespace or a
# trailing empty line never ends up as a bogus identifier in `pmcids`.
with open(input_file, 'r', encoding='utf-8') as f:
    pmcids = [entry.strip() for entry in f if entry.strip()]

print("Number of PMCIDs:", len(pmcids))

In [None]:
# Orchestrator object used by every later cell (fetching, parsing, prompt rendering, batch handling).
dg = DataGatherer(
    llm_name=model_name, 
    log_level='WARNING',  # dg.logger.info(...) calls in later cells are presumably suppressed at this level
    process_entire_document=FDR, 
    driver_path=None, 
    save_to_cache=False, 
    load_from_cache=False,
    # Embedding cache switches; exact semantics live in DataGatherer — TODO confirm.
    embeds_cache_read=True,
    embeds_cache_write=True,
)

In [None]:
# Local cache of fetched full texts: fetch (asking fetch_data to write the cache)
# only when the parquet file is missing, otherwise load it and rebuild the
# {index -> record} dictionary from the DataFrame.
fulltext_cache_path = 'scripts/exp_input/Local_fulltext_pub_REV.parquet'

if not os.path.exists(fulltext_cache_path):
    publication_fulltext = dg.fetch_data(pmcids, write_df_to_path=fulltext_cache_path)
    publication_fulltext_df = pd.DataFrame.from_dict(publication_fulltext, orient='index')
else:
    publication_fulltext_df = pd.read_parquet(fulltext_cache_path)
    publication_fulltext = publication_fulltext_df.to_dict(orient='index')

In [None]:
# Group the requested papers by the raw format they were fetched in:
# format_counts[fmt] = {'count': <n papers>, 'urls': [<paper keys>]}
format_counts = {}
for url, data in publication_fulltext.items():
    # Keep only requested papers whose record reports a raw data format.
    if url in pmcids and data and 'raw_data_format' in data:
        bucket = format_counts.setdefault(data['raw_data_format'], {'count': 0, 'urls': []})
        bucket['count'] += 1
        bucket['urls'].append(url)

# Log format frequencies (counts only for readability)
frequency_summary = dict((fmt, info['count']) for fmt, info in format_counts.items())
dg.logger.info(f"Fetched {len(publication_fulltext)} Papers. Format frequencies: {frequency_summary}")

In [None]:
# Build one batch request per retrieved snippet of every paper, grouped by raw format
# so that the parser is only re-initialised when the format changes.
# `cnt` counts papers processed without error; `last_url_raw_data_format` remembers
# the format of the last successfully processed paper.
batch_requests, cnt, last_url_raw_data_format = [], 0, False

for url_raw_data_format, vals in format_counts.items():
    for url in vals['urls']:

        # NOTE(review): assigned but never read in this cell.
        msg_already_added = ''
        
        data = publication_fulltext[url]
        dg.logger.info(f"type of data: {type(data['fetched_data'])}")

        # XML stored as a string is parsed into an lxml element first.
        # This replaces the entry inside publication_fulltext in place.
        if isinstance(data['fetched_data'], str) and url_raw_data_format.upper() == 'XML':
            dg.logger.info("string data is not supported input, need etree")
            data['fetched_data'] = etree.fromstring(data['fetched_data'].encode('utf-8'))

        try:                        
            # Re-initialise the parser only for the first paper or when the format changed.
            if cnt != 0 and url_raw_data_format == last_url_raw_data_format:
                dg.logger.info(f"Reusing existing parser of name: {dg.parser.__class__.__name__}")
            else:
                dg.logger.info(f"Creating new parser for format: {url_raw_data_format}")
                dg.init_parser_by_input_type(url_raw_data_format, data['fetched_data'], embeddings_retriever_model)
                        
            # Generate unique custom_id
            article_id = dg.url_to_page_id(url)
            # NOTE(review): `pmcid` is assigned but never read in this cell.
            pmcid = dg.data_fetcher.url_to_pmcid(url)
            timestamp = int(time.time() * 1000)
            custom_id = f"{dg.llm}_{article_id}_{timestamp}"
            # Sanitise to [a-zA-Z0-9_-] and cap at 64 characters. The per-snippet suffix
            # is appended after this cap, so the final id can be longer than 64.
            custom_id = re.sub(r'[^a-zA-Z0-9_-]', '_', custom_id)[:64]
                        
            if dg.full_document_read:
                # NOTE(review): this branch only computes `normalized_input`; no batch
                # request is built from it, so full-document mode contributes nothing
                # to `batch_requests` as written.
                dg.logger.info(f'normalize input')
                if url_raw_data_format.upper() == 'XML':
                    normalized_input = (dg.parser.normalize_XML(data['fetched_data']) 
                                        if hasattr(dg.parser, 'normalize_XML') 
                                        else data['fetched_data'])
                elif url_raw_data_format.upper() == 'HTML':
                    normalized_input = (dg.parser.normalize_HTML(data['fetched_data']) 
                                                if hasattr(dg.parser, 'normalize_HTML') 
                                                else data['fetched_data'])
                elif url_raw_data_format.upper() == 'PDF':
                    normalized_input = data['fetched_data']
                else:
                    raise ValueError(f"Unsupported raw data format: {url_raw_data_format}")
                        
            else:
                # Retrieve the relevant snippets of the paper with the configured retrieval options.
                dg.logger.info(f'relevant section retrieval')
                data_availability_obj = dg.parser.retrieve_relevant_content(
                                data['fetched_data'],
                                semantic_retrieval=semantic_retrieval,
                                top_k=top_k,
                                skip_rule_based_retrieved_elm=dedup,
                                include_snippets_with_ID_patterns=brute_force_RegEx_ID_ptrs,
                                article_id=dg.data_fetcher.url_to_pmcid(url),
                                output_format='json'
                            )
                dg.logger.info(f"type of data_availability_obj: {type(data_availability_obj)}")
                dg.logger.info(f"length of data_availability_obj: {len(data_availability_obj)}")
                dg.logger.info(f"data_availability_obj content: {data_availability_obj}")

                # One request per snippet; snippets may be dicts with a 'text' key or plain strings.
                for idx, obj in enumerate(data_availability_obj):
                    dg.logger.info(f"Object type in data_availability_obj: {type(obj)}")

                    # Extract text and metadata from obj
                    if isinstance(obj, dict) and 'text' in obj:
                        normalized_input = obj['text']
                        # Preserve all other attributes in metadata (excluding 'text')
                        obj_metadata = {k: v for k, v in obj.items() if k != 'text'}
                    elif isinstance(obj, str):
                        normalized_input = obj
                        obj_metadata = {}
                    else:
                        dg.logger.warning(f"Unsupported object type in data_availability_obj: {type(obj)}")
                        continue

                    # Render prompt using the correct parser
                    # (the template is reloaded for every snippet although `prompt_name` does not change)
                    static_prompt = dg.parser.prompt_manager.load_prompt(prompt_name)
                    messages = dg.parser.prompt_manager.render_prompt(
                                    static_prompt,
                                    entire_doc=dg.full_document_read,
                                    content=normalized_input,
                                    repos=', '.join(dg.parser.repo_names) if hasattr(dg.parser, 'repo_names') else '',
                                    url=url
                                )
                    
                    # Create unique custom_id for each snippet
                    snippet_custom_id = f"{custom_id}_snippet_{idx}"
                                
                    # Create batch request for LLMClient
                    batch_request = {
                                    'custom_id': snippet_custom_id,
                                    'messages': messages,
                                    'metadata': {
                                        'url': url,
                                        'article_id': article_id,
                                        'raw_data_format': url_raw_data_format,
                                        'snippet_index': idx,
                                        # Keys coming from obj override the four keys above on collision.
                                        **obj_metadata  # Preserve all attributes from obj
                                    }
                                }
                                
                    batch_requests.append(batch_request)
                        
        except Exception as e:
            # Best effort: log and move on to the next paper. Requests already appended
            # for earlier snippets of this paper stay in `batch_requests`.
            dg.logger.error(f"Error preparing request for {url}: {e}")
            continue

        # Only reached when the paper was processed without raising.
        last_url_raw_data_format = url_raw_data_format
        cnt+=1
            
dg.logger.info(f"Prepared {len(batch_requests)} batch requests")

In [None]:
# Number of snippet-level requests prepared by the previous cell.
len(batch_requests)

In [None]:
# Hand the prepared requests to the LLM client's batch handler together with the
# target JSONL path. Uses whichever parser the preparation loop left on `dg.parser`.
# NOTE(review): `_handle_batch_mode` is a private method of the client, and
# `dataset_response_schema_gpt` presumably comes from the star import at the top — confirm.
batch_result = dg.parser.llm_client._handle_batch_mode(
                batch_requests=batch_requests,
                batch_file_path=batch_file_path,
                temperature=0,
                response_format=dataset_response_schema_gpt,
                api_provider='openai'
                )
            
# Summary of this step, displayed in the next cell.
result = {
    'batch_file_created': batch_result,
    'fetched_data_count': len(publication_fulltext),
    'processed_requests': len(batch_requests),
    'api_provider': 'openai',
    'model': dg.llm
    }

In [None]:
# Display the summary dictionary built in the previous cell.
result

In [None]:
# Re-read the batch request file from disk: one JSON object per non-blank line.
prompts_filepath = batch_file_path
with open(prompts_filepath, 'r', encoding='utf-8') as f:
    prompts_load = [json.loads(raw_line) for raw_line in f if raw_line.strip()]
len(prompts_load)

In [None]:
# Sanity check of the prompts against the ground truth: for every prompt, what
# fraction of the paper's ground-truth identifiers appears verbatim
# (case-insensitively) in the prompt text?
#   found_avg    - mean of that fraction over all prompts
#   bad_input    - number of prompts containing none of the identifiers
#   empty_prompt - custom_ids of those prompts
# NOTE(review): prompts of papers without any ground-truth identifier score 1.0
# in found_avg but are still counted in bad_input — confirm this is intended.
found_avg, bad_input, empty_prompt = 0, 0, []
n_dfs = len(prompts_load)

for prompt in prompts_load:
    pmc_id = dg.data_fetcher.url_to_pmcid(prompt['custom_id'])
    gt = df_gt[df_gt['pmcid'] == pmc_id]
    # Null identifiers cannot be searched for (and would break .lower()), so they
    # are dropped; the remaining ones are compared as strings.
    datasets_gt = gt['identifier'].dropna().astype(str).tolist()
    body_msg = [item['content'] for item in prompt['body']['input']]
    input_cont_str = "\n".join(body_msg)
    # Lower-case the prompt once instead of once per identifier.
    input_cont_lower = input_cont_str.lower()

    datasets_found, datasets_tot = 0, len(datasets_gt)
    for dataset in datasets_gt:
        if dataset.lower() in input_cont_lower:
            datasets_found += 1
        else:
            print(f"Missing dataset {dataset} in prompt {prompt['custom_id']} for pmcid {pmc_id}")

    contains_one = datasets_found > 0
    if not contains_one:
        bad_input += 1
        empty_prompt.append(prompt['custom_id'])
    found_i = datasets_found / datasets_tot if datasets_tot > 0 else 1.0
    found_avg += found_i/n_dfs

In [None]:
# Mean identifier coverage per prompt, and number of prompts containing no ground-truth identifier.
found_avg, bad_input

In [None]:
# Simple chunking and submission - NO monitoring or result combination
# NOTE(review): this rebinds `result`, replacing the summary dictionary built earlier.
result = dg.split_jsonl_and_submit(
    batch_file_path=batch_file_path,
    max_file_size_mb=200.0,
    api_provider='openai',
    wait_between_submissions=30,  # presumably seconds — TODO confirm
    batch_description=f"Training Dataset Creation"
)

In [None]:
# Identifier of the submitted batch whose results are downloaded below.
# NOTE(review): hard-coded; it has to be replaced by hand after every new submission.
batch_train = 'batch_691bf0fcc7d081909d396c54e2082ced'

In [None]:
# A parser is needed to reach the LLM client; create one when none exists yet
# (e.g. when the notebook is resumed here without running the preparation cell).
if not dg.parser:
    dg.init_parser_by_input_type('XML')

# Download the finished batch into `output_batch_file`.
res = dg.parser.llm_client.download_batch_results(
    batch_id=batch_train,
    output_file_path=output_batch_file,
    api_provider='openai'
)

In [None]:
# Load the downloaded results file line by line (line endings kept) and report its size.
with open(output_batch_file, 'r') as f:
    lines = list(f)
print(f"Number of lines in combined file: {len(lines)}")

In [None]:
# CSV file that receives the tabular version of the batch results.
ret_file = 'scripts/output/semantic_search/Train_results.csv'
ret_file

In [None]:
# Make sure a parser exists before converting the results (same guard as in the download cell).
if not dg.parser:
    dg.init_parser_by_input_type('XML')

# Convert the raw batch response file to a DataFrame; `ret_file` is passed as the output path.
res_df = dg.from_batch_resp_file_to_df(output_batch_file, output_file_path=ret_file)

In [None]:
# Reload the results from the CSV so the following cells work on what is stored on disk.
res_df = pd.read_csv(ret_file)
len(res_df)

In [None]:
# Requested articles that produced no row in the results.
# Result URLs are normalised by dropping the slash that follows the PMC article
# path, then lower-cased; the requested list is lower-cased the same way.
# Note that `pmcids` is rebound from the original list to a lower-cased set here.
pmc_url_pattern = '(https://www.ncbi.nlm.nih.gov/pmc/articles/.*)/'
pmcids_ret = {
    re.sub(pmc_url_pattern, '\\1', source_url).lower()
    for source_url in res_df['source_url'].to_list()
}
pmcids = {requested.lower() for requested in pmcids}
missing_urls = list(pmcids - pmcids_ret)
len(missing_urls)

In [None]:
# Fallback run for the articles that are missing from the batch results.
# Uses the prompt configured in the config cell so that this run cannot silently
# diverge from the batch run when `prompt_name` is changed.
new_datasets_append = dg.process_articles(
    missing_urls,
    prompt_name=prompt_name,
    full_document_read=FDR,
    top_k = top_k,
    semantic_retrieval=semantic_retrieval,
    response_format=dataset_response_schema_gpt
)

In [None]:
# Size and container type of the fallback results (the union cell accesses it as a mapping keyed by article link).
len(new_datasets_append), type(new_datasets_append)

In [None]:
# Union of the batch results with ALL per-article frames from the fallback run.
# A single concat keeps every frame (the previous loop rebuilt `final_df` from
# `res_df` on each iteration, so only the last article survived) and still
# defines `final_df` when the fallback run returned nothing.
final_df = pd.concat([res_df, *new_datasets_append.values()], ignore_index=True)

In [None]:
# Overwrite the results CSV with the combined table.
final_df.to_csv(ret_file, index=False)

In [None]:
# Reload the combined results from disk; `res_df` now refers to the table written in the previous cell.
res_df = pd.read_csv(ret_file)
res_df.head(2)

In [None]:
# Row count of the reloaded results table.
len(res_df)

In [None]:
# train_dict creation: pair every request in the batch request file with its
# response in the batch results file, matched on custom_id.
train_dict = {}

# Character offset at which the kept context starts inside the second message.
# NOTE(review): 1016 is presumably the length of the static preamble of the
# configured prompt — confirm, and update it whenever the prompt template changes.
PROMPT_PREFIX_CHARS = 1016


def _extract_output_text(result_record):
    """Return body.output[0].content[0].text of one batch result line, or None.

    Every level is optional: a missing or null 'response', 'body', 'output' or
    'content' (for instance on a failed request) yields None instead of raising.
    """
    response = result_record.get('response') or {}
    body = response.get('body') or {}
    output_items = body.get('output') or []
    if not output_items:
        return None
    content_items = output_items[0].get('content') or []
    if not content_items:
        return None
    return content_items[0].get('text')


# Load prompts from the request file
prompts_file = batch_file_path
prompts_data = {}

with open(prompts_file, 'r', encoding='utf-8') as f:
    for line in f:
        if not line.strip():
            continue
        prompt_record = json.loads(line)
        prompt_id = prompt_record['custom_id']

        # Content of every message in body.input that has one
        content_values = [msg['content'] for msg in prompt_record['body']['input'] if 'content' in msg]
        if len(content_values) < 2:
            # The context is taken from the second message; without it there is nothing to keep.
            print(f"Warning: Fewer than two messages for custom_id: {prompt_id}; skipping")
            continue

        prompts_data[prompt_id] = {
            'input_content': content_values[1][PROMPT_PREFIX_CHARS:],
            'metadata': prompt_record.get('metadata', {})
        }

# Load results from the results file
results_file = output_batch_file
results_data = {}

with open(results_file, 'r', encoding='utf-8') as f:
    for line in f:
        if not line.strip():
            continue
        result_record = json.loads(line)
        results_data[result_record['custom_id']] = {
            'output_text': _extract_output_text(result_record)
        }

# Combine prompts and results into train_dict
for prompt_id, prompt_entry in prompts_data.items():
    if prompt_id not in results_data:
        print(f"Warning: No result found for custom_id: {prompt_id}")
        continue
    train_dict[prompt_id] = {
        'custom_id': prompt_id,
        'input_content': prompt_entry['input_content'],
        'output_text': results_data[prompt_id]['output_text'],
        'metadata': prompt_entry['metadata']
    }

print(f"\nCreated train_dict with {len(train_dict)} entries")
print(f"Missing results: {len(prompts_data) - len(train_dict)}")

In [None]:
# Print the first three entries of train_dict (id, metadata, input, output) as a spot check.
iter_n = 0
for custom_id, entry in train_dict.items():
    print(custom_id)
    for field in ('metadata', 'input_content', 'output_text'):
        print(entry[field])
    iter_n += 1
    if iter_n == 3:
        break

## Dataset Validation
Convert `train_dict` to a DataFrame so the prompt/response pairs can be validated and exported.

In [None]:
# Flatten train_dict into one record per prompt/response pair.
# Metadata fields missing from an entry become empty strings; 'validated' and
# 'notes' are empty columns to be filled in during manual review.
metadata_fields = ('url', 'article_id', 'raw_data_format', 'snippet_index',
                   'section_title', 'sec_type', 'L2_distance')

validation_records = []
for entry in train_dict.values():
    content = entry['input_content']
    meta = entry['metadata']
    validation_records.append({
        'custom_id': entry['custom_id'],
        **{field: meta.get(field, '') for field in metadata_fields},
        # A list of input parts is joined into one readable string
        'input_text': "\n---\n".join(content) if isinstance(content, list) else str(content),
        'output_text': entry['output_text'],
        'validated': '',
        'notes': '',
    })

validation_df = pd.DataFrame(validation_records)
print(f"Created validation DataFrame with {len(validation_df)} rows")
validation_df.head()

In [None]:
# Export to Excel for validation (recommended - preserves formatting and allows filtering)
validation_df.to_excel(excel_output_path, index=False, engine='openpyxl')
print(f"Exported to Excel: {excel_output_path}")

# Only the Excel workbook is written by this notebook, so the instructions no
# longer point to a CSV file that does not exist.
print("\nYou can now:")
print("1. Open the Excel file to validate with filtering, sorting, and cell-by-cell editing")
print("2. Use the DataFrame below for in-notebook validation")

## Smart Duplicate Detection & Validation

This section provides an intelligent duplicate detection system that:
1. Identifies rows with identical `output_text` values
2. Compares `input_text` using word frequency analysis
3. Shows differences and prompts for manual validation
4. Helps you efficiently curate your training dataset

In [None]:
from collections import Counter
import json
from IPython.display import display, HTML, clear_output
import ipywidgets as widgets

def count_word_frequencies(text):
    """Return a Counter of the lower-cased words of `text`.

    Words are whitespace-separated tokens reduced to their alphanumeric
    characters plus '-' and '_'; tokens that end up empty are ignored.
    A missing value (as judged by pd.isna) gives an empty Counter.
    """
    if pd.isna(text):
        return Counter()

    extra_allowed = '-_'
    tally = Counter()
    for token in str(text).lower().split():
        # Strip punctuation by keeping only the allowed characters
        kept = ''.join(filter(lambda ch: ch.isalnum() or ch in extra_allowed, token))
        if kept:
            tally[kept] += 1
    return tally

def compare_word_frequencies(freq1, freq2):
    """
    Report the words whose counts differ between two frequency mappings.

    Returns a dict {word: {'current': count in freq1, 'previous': count in freq2}}
    containing only the words with unequal counts; a word missing from one
    mapping counts as 0 there.
    """
    vocabulary = set(freq1.keys()) | set(freq2.keys())
    paired_counts = (
        (token, freq1.get(token, 0), freq2.get(token, 0)) for token in vocabulary
    )
    return {
        token: {'current': in_current, 'previous': in_previous}
        for token, in_current, in_previous in paired_counts
        if in_current != in_previous
    }

def format_comparison_html(current_row, past_row, word_diffs, idx_current, idx_past):
    """
    Create an HTML formatted comparison between two rows.

    Parameters:
        current_row, past_row: mappings (e.g. DataFrame rows) providing
            'custom_id', 'url', 'article_id', 'section_title', 'sec_type',
            'L2_distance', 'input_text' and 'output_text'.
        word_diffs: {word: {'current': int, 'previous': int}} as produced by
            compare_word_frequencies; the 50 largest differences are listed.
        idx_current, idx_past: row indices shown in the two table headings.

    Returns:
        str: an HTML fragment. Every value taken from the rows or from
        word_diffs is HTML-escaped, so markup inside the data (XML/HTML
        snippets, JSON) is displayed literally instead of being interpreted.
    """
    from html import escape  # standard library; kept local to this helper

    max_word_rows = 50
    preview_chars = 500

    def esc(value):
        """Escape any value for safe inclusion in HTML text."""
        return escape(str(value))

    def metadata_table(row):
        """Render the identifying fields of one row as a two-column table."""
        return f"""
        <table style="width: 100%; border-collapse: collapse; margin: 10px 0;">
            <tr style="background: #dfe6e9;">
                <td style="padding: 8px; border: 1px solid #b2bec3; font-weight: bold; width: 150px;">Custom ID</td>
                <td style="padding: 8px; border: 1px solid #b2bec3;">{esc(row['custom_id'])}</td>
            </tr>
            <tr>
                <td style="padding: 8px; border: 1px solid #b2bec3; font-weight: bold;">URL</td>
                <td style="padding: 8px; border: 1px solid #b2bec3;">{esc(row['url'])}</td>
            </tr>
            <tr style="background: #dfe6e9;">
                <td style="padding: 8px; border: 1px solid #b2bec3; font-weight: bold;">Article ID</td>
                <td style="padding: 8px; border: 1px solid #b2bec3;">{esc(row['article_id'])}</td>
            </tr>
            <tr>
                <td style="padding: 8px; border: 1px solid #b2bec3; font-weight: bold;">Section</td>
                <td style="padding: 8px; border: 1px solid #b2bec3;">{esc(row['section_title'])} ({esc(row['sec_type'])})</td>
            </tr>
            <tr style="background: #dfe6e9;">
                <td style="padding: 8px; border: 1px solid #b2bec3; font-weight: bold;">L2 Distance</td>
                <td style="padding: 8px; border: 1px solid #b2bec3;">{esc(row['L2_distance'])}</td>
            </tr>
        </table>
        """

    def preview(text):
        """Escaped head of a text; '...' is appended only if it was cut."""
        text = str(text)
        head = text[:preview_chars]
        if len(text) > preview_chars:
            head += "..."
        return escape(head)

    parts = [f"""
    <div style="font-family: monospace; border: 2px solid #333; padding: 20px; margin: 10px 0; background: #f9f9f9;">
        <h3 style="color: #d63031;">üîç Potential Duplicate Detected</h3>
        <hr>
        
        <h4 style="color: #0984e3;">Current Row (Index: {esc(idx_current)})</h4>
        {metadata_table(current_row)}
        <h4 style="color: #6c5ce7;">Previous Row (Index: {esc(idx_past)})</h4>
        {metadata_table(past_row)}
        <hr>
        <h4 style="color: #e17055;">üìä Word Frequency Differences in Input Text</h4>
        <p style="color: #636e72;">Words appearing different number of times between current and previous input texts:</p>
        <table style="width: 100%; border-collapse: collapse; margin: 10px 0;">
            <tr style="background: #2d3436; color: white;">
                <th style="padding: 8px; border: 1px solid #000;">Word</th>
                <th style="padding: 8px; border: 1px solid #000;">Current Count</th>
                <th style="padding: 8px; border: 1px solid #000;">Previous Count</th>
                <th style="padding: 8px; border: 1px solid #000;">Difference</th>
            </tr>
    """]

    # Sort by absolute difference
    sorted_diffs = sorted(word_diffs.items(),
                          key=lambda x: abs(x[1]['current'] - x[1]['previous']),
                          reverse=True)

    for i, (word, counts) in enumerate(sorted_diffs[:max_word_rows]):
        bg_color = "#dfe6e9" if i % 2 == 0 else "#ffffff"
        diff = counts['current'] - counts['previous']
        # Green when the word is more frequent in the current text, red otherwise
        diff_color = "#00b894" if diff > 0 else "#d63031"
        parts.append(f"""
            <tr style="background: {bg_color};">
                <td style="padding: 8px; border: 1px solid #b2bec3;"><code>{esc(word)}</code></td>
                <td style="padding: 8px; border: 1px solid #b2bec3; text-align: center;">{counts['current']}</td>
                <td style="padding: 8px; border: 1px solid #b2bec3; text-align: center;">{counts['previous']}</td>
                <td style="padding: 8px; border: 1px solid #b2bec3; text-align: center; color: {diff_color}; font-weight: bold;">{diff:+d}</td>
            </tr>
        """)

    if len(sorted_diffs) > max_word_rows:
        parts.append(f"""
            <tr>
                <td colspan="4" style="padding: 8px; text-align: center; font-style: italic; color: #636e72;">
                    ... and {len(sorted_diffs) - max_word_rows} more word differences
                </td>
            </tr>
        """)

    parts.append(f"""
        </table>
        
        <hr>
        <h4 style="color: #00b894;">üìÑ Input Text Comparison</h4>
        <div style="margin: 10px 0;">
            <p style="font-weight: bold; color: #0984e3;">Current Input Text (first 500 chars):</p>
            <pre style="background: #ecf0f1; padding: 10px; border-left: 4px solid #0984e3; overflow-x: auto;">{preview(current_row['input_text'])}</pre>
        </div>
        <div style="margin: 10px 0;">
            <p style="font-weight: bold; color: #6c5ce7;">Previous Input Text (first 500 chars):</p>
            <pre style="background: #ecf0f1; padding: 10px; border-left: 4px solid #6c5ce7; overflow-x: auto;">{preview(past_row['input_text'])}</pre>
        </div>
        
        <hr>
        <h4 style="color: #fdcb6e;">üéØ Output Text (Identical)</h4>
        <pre style="background: #fffbea; padding: 10px; border-left: 4px solid #fdcb6e; overflow-x: auto;">{esc(current_row['output_text'])}</pre>
    </div>
    """)

    return "".join(parts)

# Confirmation that the helper definitions of this cell were executed.
print("‚úÖ Helper functions loaded successfully!")

In [None]:
# Reload the validation workbook (it may have been edited by hand since the export).
validation_df = pd.read_excel(excel_output_path)
print(f"Loaded {len(validation_df)} rows from {excel_output_path}")
print(f"Columns: {validation_df.columns.tolist()}")

# Make sure both review columns exist; columns already present are left untouched.
for review_column in ('validated', 'is_duplicate'):
    if review_column not in validation_df.columns:
        validation_df[review_column] = None

In [None]:
# Build duplicate detection index with smart filtering
# Both containers are filled by the scan further down in this cell.
output_text_index = {}  # Maps output_text -> list of row indices
duplicate_pairs = []  # List of (current_idx, past_idx) tuples to review

print("üîç Building duplicate detection index...")

# Phrases that mark a response as "no dataset found" (matched case-insensitively).
no_dataset_patterns = ['n/a', 'not applicable', 'no dataset', 'none']

def is_no_dataset_response(output_text):
    """Check if output indicates no dataset found.

    Returns True for a missing value, for a response that mentions
    'dataset_identifier' together with a quoted "n/a" value, or for a response
    containing one of `no_dataset_patterns` as a standalone phrase.

    Patterns must not be embedded in a longer word: "scRNA/ATAC" does not count
    as "n/a" and "nonetheless" does not count as "none". Plain substring
    matching misclassified such real extractions as "no dataset".
    """
    if pd.isna(output_text):
        return True
    text_lower = str(output_text).lower()
    # Check if it's a JSON with n/a values
    if 'dataset_identifier' in text_lower and '"n/a"' in text_lower:
        return True
    # Check for other no-dataset indicators, bounded by non-word characters
    return any(
        re.search(rf'(?<!\w){re.escape(pattern)}(?!\w)', text_lower)
        for pattern in no_dataset_patterns
    )

# Scan the validation table once:
#   - rows whose response means "no dataset" are set aside in no_dataset_rows
#   - every other row is indexed by its output_text; a row whose output_text was
#     seen before is paired with the FIRST row that had it (one pair per repeat,
#     which avoids a quadratic number of pairs)
no_dataset_rows = []
real_duplicate_pairs = []

for idx, row in validation_df.iterrows():
    output_text = row['output_text']

    # Separate handling for "no dataset" responses
    if is_no_dataset_response(output_text):
        no_dataset_rows.append(idx)
        continue

    # For real dataset responses, track duplicates
    rows_with_same_output = output_text_index.setdefault(output_text, [])
    if rows_with_same_output:
        real_duplicate_pairs.append((idx, rows_with_same_output[0]))
    rows_with_same_output.append(idx)

# Summary
duplicate_pairs = real_duplicate_pairs
no_dataset_count = len(no_dataset_rows)
total_rows = len(validation_df)
# Guard against an empty table, which would otherwise divide by zero
no_dataset_pct = no_dataset_count / total_rows * 100 if total_rows else 0.0
# Number of ROWS sharing their output_text with another row (every member of a
# group of two or more), not the number of such groups
rows_in_duplicate_groups = sum(len(v) for v in output_text_index.values() if len(v) > 1)

print(f"\n‚úÖ Analysis complete!")
print(f"üìä No-dataset responses: {no_dataset_count} rows ({no_dataset_pct:.1f}%)")
print(f"üìä Real dataset responses: {total_rows - no_dataset_count} rows")
print(f"\nüîç Duplicate detection (real datasets only):")
print(f"   Unique output_text values: {len(output_text_index)}")
print(f"   Duplicate pairs to review: {len(duplicate_pairs)}")
print(f"   Total rows with duplicate output_text: {rows_in_duplicate_groups}")

print(f"\nüí° Tip: {no_dataset_count} 'no dataset' rows were excluded from duplicate review.")
print(f"   You can mark these as valid automatically or review them separately.")

### Handle "No Dataset" Responses

The cells below help you decide what to do with rows where no datasets were found.

In [None]:
# Analyze "no dataset" responses
print("üìä 'No Dataset' Response Analysis")
print("=" * 60)

# Get all no-dataset rows. `no_dataset_rows` holds index LABELS collected from
# iterrows(), so label-based .loc is the matching indexer (.iloc only gives the
# same rows while the index happens to be 0..n-1).
no_dataset_df = validation_df.loc[no_dataset_rows]

# Group by output_text pattern, most frequent first
# (rows with a missing output_text are not part of any group)
no_dataset_groups = no_dataset_df.groupby('output_text').size().sort_values(ascending=False)

print(f"\nTotal 'no dataset' responses: {len(no_dataset_rows)}")
print(f"Unique patterns: {len(no_dataset_groups)}")

print(f"\nüìã Top 5 patterns:")
for pattern_text, count in no_dataset_groups.head(5).items():
    # Show at most 100 characters of each pattern
    output_preview = str(pattern_text)[:100] + "..." if len(str(pattern_text)) > 100 else str(pattern_text)
    print(f"\n  Pattern appears {count} times:")
    print(f"  {output_preview}")

print(f"\nüí° Options:")
print(f"  1. Keep ONE sample of each pattern for training (recommended)")
print(f"  2. Keep ALL no-dataset responses (for training variety)")
print(f"  3. Remove ALL no-dataset responses (if not needed)")
print(f"\nRun the appropriate cell below based on your choice.")

In [None]:
# Keep ALL "no dataset" responses for training variety
# These are already excluded from duplicate review, so no action needed!
# This cell only prints a summary; it does not modify validation_df.

print(f"üìä No-dataset response handling:")
print(f"   Total no-dataset responses: {len(no_dataset_rows)}")
print(f"   Action: KEEPING ALL for training variety")
print(f"\n‚úÖ All {len(no_dataset_rows)} no-dataset responses will be kept in the dataset")
print(f"\nüí° Note: These rows are already excluded from duplicate review.")
print(f"   You will only review duplicates among real dataset extractions.")
print(f"\nüéØ No changes made to the dataset.")

In [None]:
# Interactive duplicate review widget
class DuplicateReviewer:
    """Interactive ipywidgets interface for reviewing candidate duplicate rows.

    Each entry of ``duplicate_pairs`` is a ``(current_idx, past_idx)`` pair of
    *positional* row numbers into ``df`` (rows are fetched with ``.iloc``).
    Every decision is stored in ``self.decisions`` and mirrored into the
    ``is_duplicate`` / ``notes`` columns of ``df``, which is modified in place.

    A row's flag is always derived from *all* decisions that mention it:
      * flagged (``is_duplicate=True`` plus a note) if any decision marks it;
      * otherwise ``False`` if a "not duplicate" decision mentions it;
      * otherwise restored to the value it had before the review touched it.
    This keeps one pair's decision from overwriting another pair's, and lets
    a decision be revised (via "Previous") without leaving stale flags behind.
    """

    # For each "duplicate" decision: the note written to the (current, previous)
    # row of the pair. None means the decision does not flag that row.
    _DUPLICATE_NOTES = {
        'current_duplicate': ('Duplicate (current row)', None),
        'past_duplicate': (None, 'Duplicate (previous row)'),
        'both_duplicate': ('Duplicate (both rows)', 'Duplicate (both rows)'),
    }

    def __init__(self, df, duplicate_pairs):
        """Store the data to review and build the widgets.

        Args:
            df: DataFrame under review; receives ``is_duplicate``/``notes`` updates.
            duplicate_pairs: sequence of ``(current_idx, past_idx)`` row positions.
        """
        self.df = df
        self.duplicate_pairs = duplicate_pairs
        self.current_pair_idx = 0
        self.decisions = {}  # Maps (current_idx, past_idx) -> decision
        self._baseline = {}  # Maps row position -> (is_duplicate, notes) before review

        # Create widgets
        self.output_area = widgets.Output()

        # Buttons to mark which row(s) of the pair are duplicates
        self.mark_current_dup_btn = self._make_button("üî¥ Current is Duplicate", 'danger', '220px')
        self.mark_past_dup_btn = self._make_button("üî¥ Previous is Duplicate", 'danger', '220px')
        self.mark_both_dup_btn = self._make_button("üî¥ Both are Duplicates", 'danger', '220px')

        self.not_duplicate_btn = self._make_button("‚úÖ Not Duplicates", 'success', '200px')
        self.skip_btn = self._make_button("‚è≠Ô∏è Skip for Now", 'info', '200px')
        self.prev_btn = self._make_button("‚¨ÖÔ∏è Previous", 'warning', '150px')
        self.save_btn = self._make_button("üíæ Save Progress", 'primary', '200px')

        # Progress label
        self.progress_label = widgets.HTML()

        # Bind button actions
        self.mark_current_dup_btn.on_click(self.mark_current_duplicate)
        self.mark_past_dup_btn.on_click(self.mark_past_duplicate)
        self.mark_both_dup_btn.on_click(self.mark_both_duplicate)
        self.not_duplicate_btn.on_click(self.mark_not_duplicate)
        self.skip_btn.on_click(self.skip)
        self.prev_btn.on_click(self.previous)
        self.save_btn.on_click(self.save_progress)

    @staticmethod
    def _make_button(description, button_style, width):
        """Build one 50px-high action button with the given label, style and width."""
        return widgets.Button(
            description=description,
            button_style=button_style,
            layout=widgets.Layout(width=width, height='50px')
        )

    def _get_cell(self, row_pos, column):
        """Return the value at row position ``row_pos`` in ``column`` (None if the column is absent)."""
        if column not in self.df.columns:
            return None
        return self.df.iloc[row_pos, self.df.columns.get_loc(column)]

    def _set_cell(self, row_pos, column, value):
        """Write ``value`` at row position ``row_pos``, creating ``column`` if needed.

        Writing by position keeps this consistent with the ``.iloc`` reads in
        ``show_current_pair``; a label-based ``.loc`` write would hit the wrong
        row (or append a new one) whenever the index is not 0..n-1.
        """
        if column not in self.df.columns:
            self.df[column] = None
        elif self.df[column].dtype != object:
            # The review columns mix True/False/missing and free text; use object
            # dtype so the write does not depend on implicit dtype upcasting.
            self.df[column] = self.df[column].astype(object)
        self.df.iloc[row_pos, self.df.columns.get_loc(column)] = value

    def _flagged_rows(self):
        """Return the set of row positions marked as duplicate by any decision."""
        flagged = set()
        for (current_idx, past_idx), decision in self.decisions.items():
            current_note, past_note = self._DUPLICATE_NOTES.get(decision, (None, None))
            if current_note is not None:
                flagged.add(current_idx)
            if past_note is not None:
                flagged.add(past_idx)
        return flagged

    def _sync_row(self, row_pos):
        """Recompute ``is_duplicate``/``notes`` for one row from all recorded decisions."""
        if row_pos not in self._baseline:
            # Remember the pre-review values so a revised decision can restore them
            self._baseline[row_pos] = (
                self._get_cell(row_pos, 'is_duplicate'),
                self._get_cell(row_pos, 'notes'),
            )

        note = None      # note of a decision that flags this row, if any
        kept = False     # True if some "not duplicate" decision mentions this row
        for (current_idx, past_idx), decision in self.decisions.items():
            if decision == 'not_duplicate':
                kept = kept or row_pos in (current_idx, past_idx)
                continue
            current_note, past_note = self._DUPLICATE_NOTES.get(decision, (None, None))
            if row_pos == current_idx and current_note is not None:
                note = current_note
            if row_pos == past_idx and past_note is not None:
                note = past_note

        if note is not None:
            self._set_cell(row_pos, 'is_duplicate', True)
            self._set_cell(row_pos, 'notes', note)
            return

        base_flag, base_note = self._baseline[row_pos]
        self._set_cell(row_pos, 'is_duplicate', False if kept else base_flag)
        if 'notes' in self.df.columns:
            # Drops a note left by a decision that has since been revised
            self._set_cell(row_pos, 'notes', base_note)

    def _record_decision(self, decision):
        """Store ``decision`` for the current pair, update both rows, and advance."""
        if self.current_pair_idx >= len(self.duplicate_pairs):
            return
        pair = tuple(self.duplicate_pairs[self.current_pair_idx])
        self.decisions[pair] = decision
        for row_pos in dict.fromkeys(pair):
            self._sync_row(row_pos)
        self.current_pair_idx += 1
        self.show_current_pair()

    def show_current_pair(self):
        """Display the current duplicate pair for review, or the summary once all pairs are done."""
        with self.output_area:
            clear_output(wait=True)

            if self.current_pair_idx >= len(self.duplicate_pairs):
                print("üéâ All duplicate pairs reviewed!")
                print(f"\nüìä Review Summary:")
                current_dup = sum(1 for v in self.decisions.values() if v == 'current_duplicate')
                past_dup = sum(1 for v in self.decisions.values() if v == 'past_duplicate')
                both_dup = sum(1 for v in self.decisions.values() if v == 'both_duplicate')
                not_dup = sum(1 for v in self.decisions.values() if v == 'not_duplicate')
                skipped = len(self.duplicate_pairs) - len(self.decisions)

                print(f"   - Current marked as duplicate: {current_dup}")
                print(f"   - Previous marked as duplicate: {past_dup}")
                print(f"   - Both marked as duplicates: {both_dup}")
                print(f"   - Not duplicates: {not_dup}")
                print(f"   - Skipped: {skipped}")
                # Count distinct rows: a row flagged in several pairs is removed only once
                print(f"\n   Total rows to remove: {len(self._flagged_rows())}")
                print("\nüíæ Don't forget to save your progress!")
                return

            current_idx, past_idx = self.duplicate_pairs[self.current_pair_idx]
            current_row = self.df.iloc[current_idx]
            past_row = self.df.iloc[past_idx]

            # Update progress
            self.progress_label.value = f"<h3 style='color: #2c3e50;'>Progress: {self.current_pair_idx + 1} / {len(self.duplicate_pairs)}</h3>"

            # Compute word frequency differences
            current_freq = count_word_frequencies(current_row['input_text'])
            past_freq = count_word_frequencies(past_row['input_text'])
            word_diffs = compare_word_frequencies(current_freq, past_freq)

            # Show comparison
            html = format_comparison_html(current_row, past_row, word_diffs, current_idx, past_idx)
            display(HTML(html))

            # Show any previous decision
            pair_key = (current_idx, past_idx)
            if pair_key in self.decisions:
                prev_decision = self.decisions[pair_key]
                if prev_decision == 'current_duplicate':
                    print(f"\n‚ö†Ô∏è Previous decision: CURRENT marked as duplicate")
                elif prev_decision == 'past_duplicate':
                    print(f"\n‚ö†Ô∏è Previous decision: PREVIOUS marked as duplicate")
                elif prev_decision == 'both_duplicate':
                    print(f"\n‚ö†Ô∏è Previous decision: BOTH marked as duplicates")
                elif prev_decision == 'not_duplicate':
                    print(f"\n‚ö†Ô∏è Previous decision: NOT DUPLICATES")

    def mark_current_duplicate(self, b):
        """Mark current row as duplicate and move to next."""
        self._record_decision('current_duplicate')

    def mark_past_duplicate(self, b):
        """Mark past row as duplicate and move to next."""
        self._record_decision('past_duplicate')

    def mark_both_duplicate(self, b):
        """Mark both rows as duplicates and move to next."""
        self._record_decision('both_duplicate')

    def mark_not_duplicate(self, b):
        """Mark pair as not duplicate and move to next."""
        self._record_decision('not_duplicate')

    def skip(self, b):
        """Skip current pair and move to next (no-op once the summary is showing)."""
        # Never advance beyond the summary position, so "Previous" always
        # returns to the last pair with a single click.
        if self.current_pair_idx < len(self.duplicate_pairs):
            self.current_pair_idx += 1
        self.show_current_pair()

    def previous(self, b):
        """Go back to previous pair."""
        if self.current_pair_idx > 0:
            self.current_pair_idx -= 1
            self.show_current_pair()

    def save_progress(self, b):
        """Save the DataFrame to ``excel_output_path`` and the decisions to a JSON log next to it."""
        with self.output_area:
            print("\nüíæ Saving progress...")
            try:
                self.df.to_excel(excel_output_path, index=False, engine='openpyxl')
                print(f"‚úÖ Progress saved to {excel_output_path}")

                # Save decisions log
                decisions_log_path = excel_output_path.replace('.xlsx', '_duplicate_decisions.json')
                with open(decisions_log_path, 'w') as f:
                    # Convert tuple keys to strings for JSON serialization
                    json_decisions = {f"{k[0]}_{k[1]}": v for k, v in self.decisions.items()}
                    json.dump(json_decisions, f, indent=2)
                print(f"‚úÖ Decisions log saved to {decisions_log_path}")
            except Exception as e:
                # Report the failure in the widget output instead of breaking the UI
                print(f"‚ùå Error saving: {e}")

    def display(self):
        """Display the review interface."""
        self.show_current_pair()

        # Row 1: Mark duplicate buttons
        dup_button_box = widgets.HBox(
            [self.mark_current_dup_btn, self.mark_past_dup_btn, self.mark_both_dup_btn],
            layout=widgets.Layout(justify_content='center', margin='10px 0')
        )

        # Row 2: Other action buttons
        action_button_box = widgets.HBox(
            [self.not_duplicate_btn, self.skip_btn],
            layout=widgets.Layout(justify_content='center', margin='10px 0')
        )

        # Row 3: Control buttons
        control_box = widgets.HBox(
            [self.prev_btn, self.save_btn],
            layout=widgets.Layout(justify_content='center', margin='10px 0')
        )

        # Inside a method, `display` resolves to the notebook-level display function,
        # not to this method.
        display(self.progress_label)
        display(self.output_area)
        display(dup_button_box)
        display(action_button_box)
        display(control_box)

print("‚úÖ DuplicateReviewer class loaded!")

### Start Duplicate Review

Run the cell below to start the interactive duplicate review process.

**How to use:**
- Review each pair carefully
- Check the **word frequency differences** to see how input texts differ
- Read the **input text comparison** (first 500 characters shown)
- The **output text is identical** for both rows (that's why they're flagged)
- Decide which row(s) to keep or remove:
  - 🔴 **Current is Duplicate**: Mark the current (top) row as duplicate - will be removed
  - 🔴 **Previous is Duplicate**: Mark the previous (bottom) row as duplicate - will be removed
  - 🔴 **Both are Duplicates**: Mark both rows as duplicates - both will be removed
  - ✅ **Not Duplicates**: Keep both rows - they have different contexts despite same output
  - ⏭️ **Skip for Now**: Unsure, will review later
- Use ⬅️ **Previous** to go back and review earlier pairs
- Click 💾 **Save Progress** frequently to preserve your work!

**Tip:** Usually you'll want to mark just one as duplicate (keep the better quality one), but sometimes both might be poor quality and you'll want to mark both.

In [None]:
# Launch the interactive reviewer only when there is at least one pair to review
if len(duplicate_pairs) == 0:
    print("‚úÖ No duplicate pairs found! Your dataset is clean.")
else:
    reviewer = DuplicateReviewer(validation_df, duplicate_pairs)
    reviewer.display()

### Post-Review Analysis

After reviewing duplicates, use the cells below to analyze and clean your dataset.

In [None]:
# Reload data if you're returning to the notebook after saving
# validation_df = pd.read_excel(excel_output_path)

# Show duplicate review statistics
print("üìä Duplicate Review Statistics")
print("=" * 60)

if 'is_duplicate' in validation_df.columns:
    duplicate_counts = validation_df['is_duplicate'].value_counts(dropna=False)
    print(f"Marked as duplicates: {duplicate_counts.get(True, 0)}")
    print(f"Marked as not duplicates: {duplicate_counts.get(False, 0)}")
    # Unreviewed rows hold a missing value (NaN or None). Count them with isna()
    # rather than looking up a `None` key in the value counts, which misses
    # entries stored as NaN.
    print(f"Not yet reviewed: {validation_df['is_duplicate'].isna().sum()}")
    print(f"\nTotal rows: {len(validation_df)}")
    
    # Show rows marked as duplicates (first 10 only, to keep the output short)
    duplicates = validation_df[validation_df['is_duplicate'] == True]
    if len(duplicates) > 0:
        print(f"\nüîç Rows marked as duplicates ({len(duplicates)} rows):")
        print(duplicates[['custom_id', 'article_id', 'section_title', 'output_text']].head(10))
else:
    print("No duplicate review has been performed yet.")

In [None]:
# Remove duplicates from dataset
# Run this cell after completing your duplicate review

if 'is_duplicate' in validation_df.columns:
    # Create clean dataset by removing rows marked as duplicates.
    # `!= True` keeps rows marked False as well as rows not reviewed yet (missing value).
    clean_df = validation_df[validation_df['is_duplicate'] != True].copy()

    original_row_count = len(validation_df)
    removed_row_count = original_row_count - len(clean_df)
    # Guard the percentage against an empty dataset (would divide by zero)
    reduction_pct = (removed_row_count / original_row_count * 100) if original_row_count else 0.0
    
    print(f"üìä Dataset Cleaning Summary")
    print("=" * 60)
    print(f"Original dataset size: {original_row_count} rows")
    print(f"Rows marked as duplicates: {(validation_df['is_duplicate'] == True).sum()}")
    print(f"Clean dataset size: {len(clean_df)} rows")
    print(f"Reduction: {removed_row_count} rows ({reduction_pct:.2f}%)")
    
    # Save clean dataset next to the validation workbook, with a "_clean" suffix
    clean_output_path = excel_output_path.replace('.xlsx', '_clean.xlsx')
    clean_df.to_excel(clean_output_path, index=False, engine='openpyxl')
    print(f"\n‚úÖ Clean dataset saved to: {clean_output_path}")
    
    # Update validation_df to use clean version
    # validation_df = clean_df
else:
    print("‚ö†Ô∏è No duplicate review has been performed yet. Run the duplicate detection cells first.")