In [None]:
import json
import copy
import random
import pandas as pd
from cns_llava.utils import load_variable_json
from cns_llava.utils import save_variable_json

In [None]:
def remove_extra_images(conversation, max_images=1):
    """Return a copy of ``conversation`` that keeps at most ``max_images`` image items.

    Image items are counted across the whole conversation, in order, so with the
    default ``max_images=1`` only the first image encountered survives. Every
    non-image item is kept.

    The input is NOT modified. Message dicts in these conversations are shared by
    reference with other DataFrames (row selection and list ``+`` only copy the
    outer list), so rewriting ``entry['content']`` in place would silently strip
    images from those frames too.

    Args:
        conversation: iterable of message dicts, each with a ``'content'`` iterable
            whose items are dicts carrying a ``'type'`` key.
        max_images: maximum number of ``'image'`` items to keep (default 1).

    Returns:
        A new list of message dicts. Each is a shallow copy of the original message
        with a freshly built ``'content'`` list; the item dicts themselves are reused.
    """
    images_kept = 0
    cleaned = []
    for entry in conversation:
        new_content = []
        for content in entry['content']:
            if content['type'] == 'image':
                if images_kept >= max_images:
                    continue  # over the image budget: drop this image
                images_kept += 1
            new_content.append(content)
        # Build a new dict instead of assigning into `entry`; overriding an existing
        # key keeps 'content' in its original position in the dict.
        cleaned.append({**entry, 'content': new_content})
    return cleaned

In [None]:
# Combine matching 'ddx' and 'ift' rows of a single file into 'ddx_ift' rows (preview only; nothing is written).
journal = "Neurosurgery"
typ = ""

df = pd.read_json(f"/gpfs/data/oermannlab/private_data/TheMedScrolls/FiguresJadenTextract/{journal}/full_journal_dataset_both{typ}.json")

# A ddx row and an ift row describe the same figure when all of these columns agree.
# 'in_text_mention' holds lists, which are unhashable, so a tuple copy ('_mention_key')
# stands in for it during the merge. The original column is never modified, so no value
# is round-tripped through tuple()/list() (that would turn a string into a list of
# characters and raise on None/NaN).
merge_keys = ['custom_id', 'source', 'fig_caption', 'fig_label', '_mention_key', 'image', 'paper', 'page']
mention_key = df['in_text_mention'].map(lambda v: tuple(v) if isinstance(v, list) else v)

# reset_index(drop=False) exposes each row's original label as an 'index' column, so merged
# rows can be traced back to (and dropped from) df. assign() aligns mention_key on that label.
ddx_entries = df[df['mode'] == 'ddx'].assign(_mention_key=mention_key).reset_index(drop=False)
ift_entries = df[df['mode'] == 'ift'].assign(_mention_key=mention_key).reset_index(drop=False)

# NOTE(review): if a key matches several ddx or several ift rows, merge pairs every combination,
# yielding duplicate combined rows. Consider validate='one_to_one' once uniqueness is confirmed.
merged = pd.merge(
    ddx_entries,
    ift_entries,
    on=merge_keys,
    suffixes=('_ddx', '_ift')
)

# Build the combined entry: ddx turns followed by ift turns, questions joined by a space.
merged['mode'] = 'ddx_ift'
merged['conversations'] = merged['conversations_ddx'] + merged['conversations_ift']
merged['question'] = merged['question_ddx'] + ' ' + merged['question_ift']
# Both sides share the same mention key, so the ddx copy is representative.
merged['in_text_mention'] = merged['in_text_mention_ddx']

ddx_ift_entries = merged[[
    'custom_id', 'fig_caption', 'fig_label', 'in_text_mention', 'image',
    'paper', 'page', 'question', 'mode', 'conversations', 'source'
]]

# Original row labels of every ddx/ift row that was folded into a combined entry.
ddx_indices_to_remove = merged['index_ddx']
ift_indices_to_remove = merged['index_ift']

# Keep only rows that were not paired (unpaired ddx/ift rows and any other modes).
df_filtered = df.drop(index=ddx_indices_to_remove).drop(index=ift_indices_to_remove)

# Surviving original rows first, then the new 'ddx_ift' rows, with a fresh RangeIndex.
result_df = pd.concat([df_filtered, ddx_ift_entries], ignore_index=True)

result_df['conversations'] = result_df['conversations'].apply(remove_extra_images)

In [None]:
# For every journal and split, combine matching 'ddx'/'ift' rows and write the result next to
# the input file with a "_fix" suffix.
DATA_ROOT = "/gpfs/data/oermannlab/private_data/TheMedScrolls/FiguresJadenTextract"
JOURNALS = ["Neurosurgery_Practice", "Operative_Neurosurgery", "Neurosurgery"]
SPLITS = ["", "_train", "_val", "_test"]  # "" selects full_journal_dataset_both.json (no split suffix)


def combine_ddx_ift(df):
    """Fold each matching ('ddx', 'ift') row pair of ``df`` into one 'ddx_ift' row.

    A ddx row and an ift row are paired when they agree on custom_id, source,
    fig_caption, fig_label, in_text_mention, image, paper and page. Each pair becomes
    one row whose conversation is the ddx turns followed by the ift turns, and whose
    question is the two questions joined by a space. Paired source rows are removed;
    every other row is kept. Finally every conversation is passed through
    ``remove_extra_images``.

    ``df`` itself is not modified.

    Args:
        df: DataFrame with at least the columns listed above plus 'mode', 'question'
            and 'conversations'.

    Returns:
        A new DataFrame with a fresh RangeIndex: surviving original rows first, then
        the combined 'ddx_ift' rows.
    """
    # 'in_text_mention' holds lists (unhashable), so a tuple copy is used as the merge key.
    # The real column is left untouched, avoiding a lossy tuple()/list() round trip.
    merge_keys = ['custom_id', 'source', 'fig_caption', 'fig_label', '_mention_key',
                  'image', 'paper', 'page']
    output_columns = ['custom_id', 'fig_caption', 'fig_label', 'in_text_mention', 'image',
                      'paper', 'page', 'question', 'mode', 'conversations', 'source']

    mention_key = df['in_text_mention'].map(lambda v: tuple(v) if isinstance(v, list) else v)
    # reset_index(drop=False) keeps each row's original label in an 'index' column.
    ddx_entries = df[df['mode'] == 'ddx'].assign(_mention_key=mention_key).reset_index(drop=False)
    ift_entries = df[df['mode'] == 'ift'].assign(_mention_key=mention_key).reset_index(drop=False)

    # NOTE(review): duplicate keys on either side yield every ddx x ift combination
    # (duplicate combined rows). Consider validate='one_to_one' once uniqueness is confirmed.
    merged = pd.merge(ddx_entries, ift_entries, on=merge_keys, suffixes=('_ddx', '_ift'))
    merged['mode'] = 'ddx_ift'
    merged['conversations'] = merged['conversations_ddx'] + merged['conversations_ift']
    merged['question'] = merged['question_ddx'] + ' ' + merged['question_ift']
    # Both sides share the same mention key, so the ddx copy is representative.
    merged['in_text_mention'] = merged['in_text_mention_ddx']
    ddx_ift_entries = merged[output_columns]

    # Remove every original row that was folded into a combined entry.
    df_filtered = df.drop(index=merged['index_ddx']).drop(index=merged['index_ift'])

    result_df = pd.concat([df_filtered, ddx_ift_entries], ignore_index=True)
    result_df['conversations'] = result_df['conversations'].apply(remove_extra_images)
    return result_df


for journal in JOURNALS:
    for typ in SPLITS:
        # NOTE(review): read_json infers dtypes by default (dtype=True); confirm that
        # string ids such as custom_id survive the read/write round trip unchanged.
        df = pd.read_json(f"{DATA_ROOT}/{journal}/full_journal_dataset_both{typ}.json")
        result_df = combine_ddx_ift(df)
        result_df.to_json(f"{DATA_ROOT}/{journal}/full_journal_dataset_both_fix{typ}.json",
                          orient='records')