In [1]:
!pip install ollama

[0m

In [2]:
import ollama
import json
import pandas as pd
import numpy as np
import ast

In [3]:
with open('hyperedges_pubmed.json', 'r', encoding='utf-8') as f:
    data = json.load(f)
    hyperedges = {int(k): set(v) for k, v in data.items()}

In [None]:
sorted_hyperedges = dict(
    sorted(
        hyperedges.items(),
        key=lambda item: len(item[1]),   
        reverse=True                     
    )
)

print("Hyperedges ordenados por tamaño (de mayor a menor):")
for i, (hyperedge_id, nodes) in enumerate(sorted_hyperedges.items(), 1):
    print(f"{i}. Hyperedge {hyperedge_id}: {len(nodes)} nodos → {sorted(nodes)[:10]}{'...' if len(nodes) > 10 else ''}")
    if i >= 100: 
        break

print(f"\nTotal de hyperedges: {len(sorted_hyperedges)}")

Hyperedges ordenados por tamaño (de mayor a menor):
1. Hyperedge 9742976: 171 nodos → [9794859, 10718777, 10818026, 10938048, 10938049, 10986179, 11159708, 11285304, 11679387, 11801510]...
2. Hyperedge 8366922: 154 nodos → [9219407, 9294791, 9539789, 9562345, 9669786, 9844077, 9916130, 10333910, 10334318, 10388983]...
3. Hyperedge 11832527: 131 nodos → [7729603, 9075814, 9096977, 9742976, 9742977, 10189543, 10454950, 10934451, 11092283, 11311100]...
4. Hyperedge 19479186: 130 nodos → [1468304, 1516762, 3102297, 3943285, 6376244, 7587918, 7657800, 7658949, 7962273, 8013748]...
5. Hyperedge 18776148: 125 nodos → [1348306, 1469084, 1473616, 1971174, 1983329, 2134205, 2318983, 2387197, 2466379, 2499500]...
6. Hyperedge 18664617: 121 nodos → [1446808, 1598576, 1641770, 1737841, 1835345, 2119056, 2533502, 2644534, 3134263, 3290380]...
7. Hyperedge 17349009: 88 nodos → [1697648, 2190098, 6362005, 7537670, 7657822, 7694152, 7722468, 7822811, 7907110, 7907681]...
8. Hyperedge 16215165: 83 nodos

In [None]:
df = pd.read_csv("./PubMded_clusterwise.csv")[["PMID", "Title", "Abstract"]]

papers_ids_valid = set(df.PMID) 


print("Claves en hyperedges que NO están en papers_ids_valid:")
print(set(hyperedges.keys()) - papers_ids_valid)
print("Viceversa:")
print(papers_ids_valid - set(hyperedges.keys()))


filtered_hyperedges = {}
for k, v_set in hyperedges.items():
    if k in papers_ids_valid:
        filtered_values = v_set.intersection(papers_ids_valid)
        if filtered_values:  
            filtered_hyperedges[k] = filtered_values

print("Diccionario filtrado:")
print(len(filtered_hyperedges))
print(len(df))

print(set(filtered_hyperedges.keys()) - set(df.PMID))
print(set(df.PMID) - set(filtered_hyperedges.keys()))


all_values = set().union(*filtered_hyperedges.values())

missing_in_df = all_values - set(df.PMID)
extra_in_df = set(df.PMID) - all_values

print("Values del diccionario que NO están en df:", missing_in_df)
print("IDs en df que NO están en los values:", extra_in_df)


Claves en hyperedges que NO están en papers_ids_valid:
{17874530}
Viceversa:
set()
Diccionario filtrado:
19716
19716
set()
set()
Values del diccionario que NO están en df: set()
IDs en df que NO están en los values: set()


In [None]:
print(f"Number of hyperedges: {len(filtered_hyperedges)}")
print(f"Len dataset: {len(df)}")

to_remove_by_cluster = {}

# ROL
SYSTEM_PROMPT = """
You are an expert biomedical researcher specializing in diabetes, metabolism, retinal complications, mitochondrial function, and related topics.
The PubMed-Diabetes dataset contains 19,717 scientific publications from the PubMed database pertaining to diabetes, classified into one of three classes.
Your ONLY task is: Given a list of scientific papers (each with ID, Title, and Abstract) that belong to the same cluster/hyperedge that is based on a main paper of reference, identify and remove any that are clear outliers — i.e., papers that do not fit the dominant topic/theme of the main paper.

Rules:
- Focus on thematic coherence (main disease model, organ/system studied, experimental approach, etc.).
- Be strict: only remove if the paper is noticeably off-topic compared to the majority.
- Output format: ONLY a Python list of IDs to remove, e.g. [12187484, 2344352] or [] if none should be removed.
- Do NOT explain, do NOT add any text, do NOT say "I think" or "here are...". Just the list in valid Python syntax like [id1, id2].
"""


sorted_hyperedges = dict(
    sorted(
        filtered_hyperedges.items(),
        key=lambda item: len(item[1]),   
        reverse=True                     
    )
)

top_50_hyperedges = dict(list(sorted_hyperedges.items())[:50])


for node_id, nodes_in_edge in top_50_hyperedges.items():

    print(f"\nProcesando hipergrafo de {node_id} ({len(nodes_in_edge)} papers)")
    
    sub_df = df[df['PMID'].isin(nodes_in_edge)]
    if len(sub_df) >= 67:
        sample_df = sub_df.sample(n=30, random_state=42)    
    else:
        sample_df = sub_df.sample(n=20, random_state=42)

    print(f"Length del sample: {len(sample_df)}")

    # Prompt del usuario
    user_prompt = "Here are some sample scientific articles from the same cluster/hyperedge:\n\n"
    for _, row in sample_df.iterrows():
        user_prompt += f"ID: {row['PMID']}\nTitle: {row['Title']}\nAbstract: {row['Abstract'][:800]}...\n\n"  
    
    node_row = df.loc[df['PMID'] == node_id].iloc[0]

    user_prompt += (
        f"Based on the hyperedges's main paper's Title and Abstract:\n"
        f"Title: {node_row['Title']}\n"
        f"Abstract: {node_row['Abstract'][:800]}\n"
        "Decide which IDs should be removed because they do not fit the general topic/theme well."
    )

    response = ollama.chat(
        model='gemma3:1b',
        messages=[
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': user_prompt}
        ],
        options={
            'temperature': 0.1     # bajo para determinista
        }
    )
    
    raw_output = response['message']['content'].strip()
    print("Model output:", raw_output)
    
    try:
        ids_to_remove = ast.literal_eval(raw_output)
        if isinstance(ids_to_remove, list):
            to_remove_by_cluster[node_id] = [int(id) for id in ids_to_remove] 
            print(f"→ IDs a eliminar: {ids_to_remove}")
        else:
            print("No es lista válida")
    except Exception as e:
        print(f"Error parseando: {e} → output crudo: {raw_output}")


all_removed = [id for ids in to_remove_by_cluster.values() for id in ids]
print(f"\nTotal IDs sugeridos para remover: {len(all_removed)}")
print("IDs:", all_removed)

Number of hyperedges: 19716
Len dataset: 19716

Procesando hipergrafo de 9742976 (171 papers)
Length del sample: 30
Model output: Based on the provided information, the IDs that seem least relevant to the overall topic are:

*   **ID: 18620046** - This ID seems unrelated to the core research question of identifying high-risk adults.

The remaining IDs – **ID: 15616245, ID: 17597523, ID: 19056645** – are the most relevant and could be removed as they don't contribute significantly to the overall theme of diabetes prevention and risk assessment.
Error parseando: invalid character '–' (U+2013) (<unknown>, line 5) → output crudo: Based on the provided information, the IDs that seem least relevant to the overall topic are:

*   **ID: 18620046** - This ID seems unrelated to the core research question of identifying high-risk adults.

The remaining IDs – **ID: 15616245, ID: 17597523, ID: 19056645** – are the most relevant and could be removed as they don't contribute significantly to the over

KeyboardInterrupt: 