In [None]:
import json
import os, time, re

import numpy as np
import pandas as pd
import sklearn
import sklearn.model_selection
from lxml import etree

from data_gatherer.data_gatherer import DataGatherer
from data_gatherer.llm.response_schema import *

In [None]:
# --- Ground truth ---------------------------------------------------------------
# Gold dataset-citation records; filtered by `pmcid` / `identifier` further down.
df_gt = pd.read_parquet("scripts/output/gold/dataset_citation_records_Table.parquet")

# --- File locations -------------------------------------------------------------
input_file = "scripts/exp_input/REV.txt"                                      # article list, one entry per line
batch_file_path = "scripts/tmp/train_data_openai_RTR-3_DataRef-REV.jsonl"     # batch request file
output_batch_file = "scripts/tmp/Train_results.jsonl"                         # downloaded batch responses
excel_output_path = "scripts/tmp/training_dataset_validation.xlsx"            # manual-validation sheet

# --- Model and prompt -----------------------------------------------------------
model_name = "gpt-4o-mini"
prompt_name = "GPT_FewShot"

# --- Retrieval settings ---------------------------------------------------------
FDR = False                          # forwarded as process_entire_document / full_document_read
semantic_retrieval = True
brute_force_RegEx_ID_ptrs = True     # forwarded as include_snippets_with_ID_patterns
top_k = 3
embeddings_retriever_model = None
dedup = True                         # forwarded as skip_rule_based_retrieved_elm


In [None]:
# Read the list of articles to process: one entry per line of `input_file`.
# Surrounding whitespace is stripped and blank lines are skipped so that an empty
# row or a trailing newline never ends up as a bogus identifier.
with open(input_file, 'r', encoding='utf-8') as f:
    pmcids = [entry.strip() for entry in f.read().splitlines() if entry.strip()]

print("Number of PMCIDs:", len(pmcids))

In [None]:
# Orchestrator settings, collected in one place and then unpacked into the
# constructor: no document-level caching, embedding cache read and written.
dg_options = dict(
    llm_name=model_name,
    log_level='WARNING',
    process_entire_document=FDR,
    driver_path=None,
    save_to_cache=False,
    load_from_cache=False,
    embeds_cache_read=True,
    embeds_cache_write=True,
)
dg = DataGatherer(**dg_options)

In [None]:
# Local parquet copy of the fetched full texts. When it exists it is loaded as is;
# otherwise the articles are fetched, with the same path handed to `fetch_data`
# as `write_df_to_path`.
fulltext_cache_path = 'scripts/exp_input/Local_fulltext_pub_REV.parquet'

if not os.path.exists(fulltext_cache_path):
    publication_fulltext = dg.fetch_data(pmcids, write_df_to_path=fulltext_cache_path)
    publication_fulltext_df = pd.DataFrame.from_dict(publication_fulltext, orient='index')
else:
    publication_fulltext_df = pd.read_parquet(fulltext_cache_path)
    publication_fulltext = publication_fulltext_df.to_dict(orient='index')

In [None]:
# Bucket the requested articles by the raw format of their fetched payload:
# {format: {'count': <int>, 'urls': [<url>, ...]}}
format_counts = {}
for url, data in publication_fulltext.items():
    # Skip entries that were not requested or that carry no format information.
    if url not in pmcids or not data or 'raw_data_format' not in data:
        continue
    bucket = format_counts.setdefault(data['raw_data_format'], {'count': 0, 'urls': []})
    bucket['count'] += 1
    bucket['urls'].append(url)

# Only the counts are logged; the URL lists would make the message unreadable.
frequency_summary = dict((fmt, info['count']) for fmt, info in format_counts.items())
dg.logger.info(f"Fetched {len(publication_fulltext)} Papers. Format frequencies: {frequency_summary}")

In [None]:
# Build the batch requests: one per retrieved snippet, or one per article when the
# whole document is read. Articles are visited format by format so that the parser
# created for a format can be reused for the following articles of that format.
batch_requests, cnt, last_url_raw_data_format = [], 0, False

for url_raw_data_format, vals in format_counts.items():
    for url in vals['urls']:
        data = publication_fulltext[url]
        # Work on a local reference so the cached full-text dict is never mutated.
        fetched = data['fetched_data']
        dg.logger.info(f"type of data: {type(fetched)}")

        try:
            # Parsing happens inside the try block so that one malformed document is
            # logged and skipped instead of aborting the whole loop.
            if isinstance(fetched, str) and url_raw_data_format.upper() == 'XML':
                dg.logger.info("string data is not supported input, need etree")
                fetched = etree.fromstring(fetched.encode('utf-8'))

            # `cnt` / `last_url_raw_data_format` only advance after a successful
            # article, so a parser is reused only right after a success with the
            # same format; otherwise a new one is initialised.
            if cnt != 0 and url_raw_data_format == last_url_raw_data_format:
                dg.logger.info(f"Reusing existing parser of name: {dg.parser.__class__.__name__}")
            else:
                dg.logger.info(f"Creating new parser for format: {url_raw_data_format}")
                dg.init_parser_by_input_type(url_raw_data_format, fetched, embeddings_retriever_model)

            # Generate a unique, sanitised custom_id for the article.
            # NOTE(review): the 64-character cap is applied before the
            # `_snippet_{idx}` suffix is appended, so the final id can be longer --
            # confirm against the provider's limit.
            article_id = dg.url_to_page_id(url)
            pmcid = dg.data_fetcher.url_to_pmcid(url)
            timestamp = int(time.time() * 1000)
            custom_id = f"{dg.llm}_{article_id}_{timestamp}"
            custom_id = re.sub(r'[^a-zA-Z0-9_-]', '_', custom_id)[:64]

            if dg.full_document_read:
                dg.logger.info('normalize input')
                fmt_upper = url_raw_data_format.upper()
                if fmt_upper == 'XML':
                    normalized_input = (dg.parser.normalize_XML(fetched)
                                        if hasattr(dg.parser, 'normalize_XML')
                                        else fetched)
                elif fmt_upper == 'HTML':
                    normalized_input = (dg.parser.normalize_HTML(fetched)
                                        if hasattr(dg.parser, 'normalize_HTML')
                                        else fetched)
                elif fmt_upper == 'PDF':
                    normalized_input = fetched
                else:
                    raise ValueError(f"Unsupported raw data format: {url_raw_data_format}")

                # The whole document is treated as a single snippet so that it goes
                # through the same request-building path as retrieved snippets
                # (previously this branch produced no request at all).
                data_availability_obj = [{'text': normalized_input}]
            else:
                dg.logger.info('relevant section retrieval')
                data_availability_obj = dg.parser.retrieve_relevant_content(
                    fetched,
                    semantic_retrieval=semantic_retrieval,
                    top_k=top_k,
                    skip_rule_based_retrieved_elm=dedup,
                    include_snippets_with_ID_patterns=brute_force_RegEx_ID_ptrs,
                    article_id=pmcid,
                    output_format='json'
                )
                dg.logger.info(f"type of data_availability_obj: {type(data_availability_obj)}")
                dg.logger.info(f"length of data_availability_obj: {len(data_availability_obj)}")
                dg.logger.info(f"data_availability_obj content: {data_availability_obj}")

            # Same for every snippet of this article, so resolved once per article.
            static_prompt = dg.parser.prompt_manager.load_prompt(prompt_name)
            repos = ', '.join(dg.parser.repo_names) if hasattr(dg.parser, 'repo_names') else ''

            for idx, obj in enumerate(data_availability_obj):
                dg.logger.info(f"Object type in data_availability_obj: {type(obj)}")

                # Split each snippet into its text and the remaining attributes.
                if isinstance(obj, dict) and 'text' in obj:
                    normalized_input = obj['text']
                    # Everything except 'text' is carried along as request metadata.
                    obj_metadata = {k: v for k, v in obj.items() if k != 'text'}
                elif isinstance(obj, str):
                    normalized_input = obj
                    obj_metadata = {}
                else:
                    dg.logger.warning(f"Unsupported object type in data_availability_obj: {type(obj)}")
                    continue

                # Render the prompt with the parser that handled this article.
                messages = dg.parser.prompt_manager.render_prompt(
                    static_prompt,
                    entire_doc=dg.full_document_read,
                    content=normalized_input,
                    repos=repos,
                    url=url
                )

                batch_requests.append({
                    # One custom_id per snippet.
                    'custom_id': f"{custom_id}_snippet_{idx}",
                    'messages': messages,
                    'metadata': {
                        'url': url,
                        'article_id': article_id,
                        'raw_data_format': url_raw_data_format,
                        'snippet_index': idx,
                        **obj_metadata  # Preserve all attributes from obj
                    }
                })

        except Exception as e:
            # Best effort: a failing article is reported and skipped.
            dg.logger.error(f"Error preparing request for {url}: {e}")
            continue

        last_url_raw_data_format = url_raw_data_format
        cnt += 1

dg.logger.info(f"Prepared {len(batch_requests)} batch requests")

In [None]:
# Sanity check: number of batch requests collected in `batch_requests`.
len(batch_requests)

In [None]:
# Hand the prepared requests to the parser's LLM client in batch mode and keep a
# small summary of the run in `result`.
api_provider = 'openai'
llm_client = dg.parser.llm_client

batch_result = llm_client._handle_batch_mode(
    batch_requests=batch_requests,
    batch_file_path=batch_file_path,
    temperature=0,
    response_format=dataset_response_schema_gpt,
    api_provider=api_provider,
)

result = dict(
    batch_file_created=batch_result,
    fetched_data_count=len(publication_fulltext),
    processed_requests=len(batch_requests),
    api_provider=api_provider,
    model=dg.llm,
)

In [None]:
# Display the current `result` summary.
result

In [None]:
# Read the batch request file back in: one JSON object per non-blank line.
prompts_filepath = batch_file_path
with open(prompts_filepath, 'r', encoding='utf-8') as f:
    prompts_load = [json.loads(raw_line) for raw_line in f if raw_line.strip()]
len(prompts_load)

In [None]:
# Check how much of the ground truth is present in each prompt:
#   found_avg    - mean, over prompts, of the fraction of the article's gold
#                  identifiers that occur (case-insensitively) in the prompt text
#   bad_input    - number of prompts containing none of the gold identifiers
#   empty_prompt - custom_ids of those prompts
# NOTE(review): a prompt whose article has no gold identifiers scores 1.0 but is
# also counted in bad_input / empty_prompt -- confirm this is intended.
found_avg, bad_input, empty_prompt = 0, 0, []
n_dfs = len(prompts_load)

for prompt in prompts_load:
    pmc_id = dg.data_fetcher.url_to_pmcid(prompt['custom_id'])
    gt = df_gt[df_gt['pmcid'] == pmc_id]
    # Null identifiers cannot be searched for (and have no .lower()), so they are
    # dropped; the remaining ones are coerced to str.
    datasets_gt = gt['identifier'].dropna().astype(str).tolist()

    body_msg = [item['content'] for item in prompt['body']['input']]
    input_cont_str = "\n".join(body_msg)
    # Lower-case the prompt once instead of once per identifier.
    input_cont_lower = input_cont_str.lower()

    datasets_found, datasets_tot = 0, len(datasets_gt)
    for dataset in datasets_gt:
        if dataset.lower() in input_cont_lower:
            datasets_found += 1
        else:
            print(f"Missing dataset {dataset} in prompt {prompt['custom_id']} for pmcid {pmc_id}")

    contains_one = datasets_found > 0
    if not contains_one:
        bad_input += 1
        empty_prompt.append(prompt['custom_id'])

    found_i = datasets_found / datasets_tot if datasets_tot > 0 else 1.0
    found_avg += found_i/n_dfs

In [None]:
# Display the coverage figures: average found fraction and count of flagged prompts.
found_avg, bad_input

In [None]:
# Chunk the request file and submit the chunks -- nothing more: no monitoring and
# no combination of results happens here.
submit_options = dict(
    batch_file_path=batch_file_path,
    max_file_size_mb=200.0,
    api_provider='openai',
    wait_between_submissions=30,
    batch_description="Training Dataset Creation",
)
result = dg.split_jsonl_and_submit(**submit_options)

In [None]:
# Hard-coded id of the batch whose results are downloaded next (passed as `batch_id`).
batch_train = 'batch_691bf0fcc7d081909d396c54e2082ced'

In [None]:
# The download goes through the parser's LLM client, so make sure a parser exists;
# when none has been set up yet, an XML one is initialised.
if not dg.parser:
    dg.init_parser_by_input_type('XML')

download_options = {
    'batch_id': batch_train,
    'output_file_path': output_batch_file,
    'api_provider': 'openai',
}
res = dg.parser.llm_client.download_batch_results(**download_options)

In [None]:
# Count the lines of the downloaded results file.
with open(output_batch_file, 'r') as f:
    lines = list(f)
print(f"Number of lines in combined file: {len(lines)}")

In [None]:
# CSV location for the tabular form of the batch results (echoed as the cell output).
ret_file = 'scripts/output/semantic_search/Train_results.csv'
ret_file

In [None]:
# Turn the downloaded batch responses into a DataFrame. `ret_file` is passed as
# `output_file_path` -- presumably the table is also written there (TODO confirm).
res_df = dg.from_batch_resp_file_to_df(output_batch_file, output_file_path=ret_file)

In [None]:
# Load the results table from the CSV at `ret_file`.
res_df = pd.read_csv(ret_file)

In [None]:
# Find the requested articles that have no row in the batch results.
# The comparison is case-insensitive and ignores the trailing slash of result URLs.
# `pmcids` itself is deliberately left untouched (the cell stays re-runnable and
# later cells still see the original list); missing entries are reported in their
# original spelling and in input order.
pmcids_ret = {
    re.sub('(https://www.ncbi.nlm.nih.gov/pmc/articles/.*)/', '\\1', item).lower()
    # Rows without a source URL cannot be matched, so they are skipped.
    for item in res_df['source_url'].dropna().astype(str)
}

requested_by_key = {}
for requested in pmcids:
    # Keep the first spelling seen for each case-insensitive key.
    requested_by_key.setdefault(requested.lower(), requested)

missing_urls = [requested for key, requested in requested_by_key.items() if key not in pmcids_ret]
len(missing_urls)

In [None]:
# Re-process the articles that are missing from the batch results, using the same
# prompt and retrieval settings as the batch run (`prompt_name` comes from the
# configuration cell so the two runs cannot drift apart).
# With nothing missing the pipeline is not invoked at all and an empty mapping is
# used, which is what the union step below iterates over.
if missing_urls:
    new_datasets_append = dg.process_articles(
        missing_urls,
        prompt_name=prompt_name,
        full_document_read=FDR,
        top_k=top_k,
        semantic_retrieval=semantic_retrieval,
        response_format=dataset_response_schema_gpt
    )
else:
    new_datasets_append = {}

In [None]:
# Number of entries returned for the re-processed articles, and the container type.
len(new_datasets_append), type(new_datasets_append)

In [None]:
# Union the batch results with the rows of every re-processed article.
# A single concat keeps all articles (not just the last one iterated) and still
# defines `final_df` when nothing was re-processed.
final_df = pd.concat([res_df, *new_datasets_append.values()], ignore_index=True)

In [None]:
# Persist the combined results, overwriting the CSV at `ret_file`.
final_df.to_csv(ret_file, index=False)

In [None]:
# Re-load the results from the CSV at `ret_file` and preview the first two rows.
res_df = pd.read_csv(ret_file)
res_df.head(2)

In [None]:
# Total number of rows in the results table.
len(res_df)

In [None]:
# NOTE(review): prints a literal 'a' -- looks like leftover scratch/debug output;
# consider deleting this cell.
print('a')

In [None]:
# train_dict creation: pair every request of the batch file with its response,
# keyed by custom_id.
train_dict = {}

# Number of leading characters dropped from the second input message so that only
# the article context is kept.
# NOTE(review): presumably the length of the static preamble of the current prompt;
# it has to be re-derived whenever the prompt template changes -- confirm.
PROMPT_PREFIX_CHARS = 1016


def _extract_output_text(record):
    """Return the first text part of a batch result record, or None.

    Walks ``response -> body -> output[*] -> content[*] -> text``. A missing or
    null ``response`` / ``body`` and output items without content are tolerated,
    so a failed request yields None instead of raising.
    """
    body = (record.get('response') or {}).get('body') or {}
    for output_item in body.get('output') or []:
        if not isinstance(output_item, dict):
            continue
        for content_item in output_item.get('content') or []:
            if isinstance(content_item, dict) and 'text' in content_item:
                return content_item['text']
    return None


# Load prompts from the request file
prompts_file = batch_file_path
prompts_data = {}

with open(prompts_file, 'r', encoding='utf-8') as f:
    for line in f:
        if not line.strip():
            continue
        request = json.loads(line)
        custom_id = request['custom_id']

        # Extract content values from body.input array
        input_messages = request['body']['input']
        content_values = [msg['content'] for msg in input_messages if 'content' in msg]
        if len(content_values) < 2:
            # The context is taken from the second message; without it there is
            # nothing to train on for this request.
            print(f"Warning: request {custom_id} has fewer than two input messages; skipped")
            continue
        just_context = content_values[1][PROMPT_PREFIX_CHARS:]

        prompts_data[custom_id] = {
            'input_content': just_context,
            'metadata': request.get('metadata', {})
        }

# Load results from the results file
results_file = output_batch_file
results_data = {}

with open(results_file, 'r', encoding='utf-8') as f:
    for line in f:
        if not line.strip():
            continue
        # Named `result_record` so the `result` summary of the submission step is
        # not overwritten.
        result_record = json.loads(line)
        results_data[result_record['custom_id']] = {
            'output_text': _extract_output_text(result_record)
        }

# Combine prompts and results into train_dict
for custom_id, prompt_entry in prompts_data.items():
    if custom_id in results_data:
        train_dict[custom_id] = {
            'custom_id': custom_id,
            'input_content': prompt_entry['input_content'],
            'output_text': results_data[custom_id]['output_text'],
            'metadata': prompt_entry['metadata']
        }
    else:
        print(f"Warning: No result found for custom_id: {custom_id}")

print(f"\nCreated train_dict with {len(train_dict)} entries")
print(f"Missing results: {len(prompts_data) - len(train_dict)}")

In [None]:
# NOTE(review): always raises ZeroDivisionError -- presumably a deliberate stop so
# that "Run All" halts before the cells below. Confirm the intent; if it is a stop
# marker, an explicit error message (or removing the cell) would be clearer.
1/0

In [None]:
# Number of request/response pairs collected in `train_dict`.
len(train_dict)

In [None]:
# Preview the first three training entries: id, metadata, input and output.
iter_n = 0
for iter_n, (custom_id, entry) in enumerate(train_dict.items(), start=1):
    print(custom_id)
    for field in ('metadata', 'input_content', 'output_text'):
        print(entry[field])
    if iter_n == 3:
        break

## Dataset Validation
Convert train_dict to DataFrame for validation and export

In [None]:
# Flatten train_dict into one row per entry for manual validation.
# Metadata fields copied into their own columns (empty string when absent).
metadata_columns = (
    'url', 'article_id', 'raw_data_format', 'snippet_index',
    'section_title', 'sec_type', 'L2_distance',
)

validation_records = []
for custom_id, entry in train_dict.items():
    content = entry['input_content']
    # A list of chunks is joined with a visible separator; anything else is stringified.
    if isinstance(content, list):
        input_text = "\n---\n".join(content)
    else:
        input_text = str(content)

    meta = entry['metadata']
    record = {'custom_id': entry['custom_id']}
    record.update((column, meta.get(column, '')) for column in metadata_columns)
    record.update(
        input_text=input_text,
        output_text=entry['output_text'],
        validated='',  # filled in by the reviewer
        notes='',      # free-text reviewer notes
    )
    validation_records.append(record)

validation_df = pd.DataFrame(validation_records)
print(f"Created validation DataFrame with {len(validation_df)} rows")
validation_df.head()

In [None]:
# Export to Excel for manual validation (a spreadsheet allows filtering, sorting
# and cell-by-cell editing).
# Make sure the target directory exists so the export does not fail on a fresh checkout.
os.makedirs(os.path.dirname(excel_output_path) or '.', exist_ok=True)
validation_df.to_excel(excel_output_path, index=False, engine='openpyxl')
print(f"Exported to Excel: {excel_output_path}")

# Only the Excel file is written by this notebook, so only it is advertised.
print("\nYou can now:")
print("1. Open the Excel file to validate with filtering, sorting, and cell-by-cell editing")
print("2. Use the DataFrame below for in-notebook validation")

### Option 1: Interactive In-Notebook Validation (Optional)
Use widgets to validate entries one by one

In [None]:
# Interactive validation widget (uncomment to use)
# import ipywidgets as widgets
# from IPython.display import display, HTML

# current_idx = 0
# validation_status = {}

# def show_entry(idx):
#     if idx >= len(validation_df):
#         print("End of dataset reached!")
#         return
    
#     entry = validation_df.iloc[idx]
    
#     print("=" * 80)
#     print(f"Entry {idx + 1} / {len(validation_df)}")
#     print("=" * 80)
#     print(f"Custom ID: {entry['custom_id']}")
#     print(f"URL: {entry['url']}")
#     print(f"Article ID: {entry['article_id']}")
#     print(f"Section: {entry['section_title']} ({entry['sec_type']})")
#     print(f"L2 Distance: {entry['L2_distance']}")
#     print("\n--- INPUT TEXT ---")
#     print(entry['input_text'][:1000] + "..." if len(str(entry['input_text'])) > 1000 else entry['input_text'])
#     print("\n--- OUTPUT TEXT ---")
#     print(entry['output_text'])
#     print("=" * 80)

# def mark_valid(b):
#     global current_idx
#     validation_status[current_idx] = 'valid'
#     validation_df.loc[current_idx, 'validated'] = 'valid'
#     current_idx += 1
#     show_entry(current_idx)

# def mark_invalid(b):
#     global current_idx
#     validation_status[current_idx] = 'invalid'
#     validation_df.loc[current_idx, 'validated'] = 'invalid'
#     current_idx += 1
#     show_entry(current_idx)

# def skip_entry(b):
#     global current_idx
#     current_idx += 1
#     show_entry(current_idx)

# # Create buttons
# valid_btn = widgets.Button(description="✓ Valid", button_style='success')
# invalid_btn = widgets.Button(description="✗ Invalid", button_style='danger')
# skip_btn = widgets.Button(description="→ Skip", button_style='info')

# valid_btn.on_click(mark_valid)
# invalid_btn.on_click(mark_invalid)
# skip_btn.on_click(skip_entry)

# # Display interface
# show_entry(current_idx)
# display(widgets.HBox([valid_btn, invalid_btn, skip_btn]))

### Load Validated Data Back
After validation in Excel/CSV, load the validated data back

In [None]:
# Load validated data back from the exported Excel file (or a CSV re-saved from it)
# validated_df = pd.read_excel(excel_output_path)
# # OR, if the sheet was re-saved as CSV next to it:
# validated_df = pd.read_csv('scripts/tmp/training_dataset_validation.csv')

# Filter only valid entries
# valid_entries = validated_df[validated_df['validated'] == 'valid']
# print(f"Valid entries: {len(valid_entries)} / {len(validated_df)}")

# # Get statistics
# validation_summary = validated_df['validated'].value_counts()
# print("\nValidation Summary:")
# print(validation_summary)

In [None]:
# Split pmcids into train / val / test using sklearn (70/15/15 by default).
test_size = 0.15
val_size = 0.15
random_state = 42

# train_test_split needs an indexable sequence, and a fixed random_state only gives
# a reproducible split for a fixed input order. `pmcids` may be a set by this point,
# so the ids are de-duplicated and sorted first.
split_ids = sorted(set(pmcids))

# First split off the test set.
train_val, test_ids = sklearn.model_selection.train_test_split(
    split_ids, test_size=test_size, random_state=random_state, shuffle=True
)

# The validation share is expressed relative to what is left after the test split.
val_relative = val_size / (1.0 - test_size)
train_ids, val_ids = sklearn.model_selection.train_test_split(
    train_val, test_size=val_relative, random_state=random_state, shuffle=True
)

print(f"Total: {len(split_ids)}  -> train: {len(train_ids)}, val: {len(val_ids)}, test: {len(test_ids)}")