In [None]:
!wget -O data.zip https://github.com/MMartinelli-hub/ATA_Tutorship/raw/refs/heads/main/gutbrainie/data/GutBrainIE_Full_Collection_2025.zip
!unzip data.zip "*"
!mkdir data
!mv Annotations/ ./data
!mv Articles/ ./data
!mv Test_Data/ ./data
!rm -rf data.zip

## Load Training Data and Build Relation Vocabulary

In [5]:
import json
import os
from collections import defaultdict

# Training annotations, one JSON file per annotation-quality tier.
train_files = [
    "../data/Annotations/Train/gold_quality/json_format/train_gold.json",
    "../data/Annotations/Train/platinum_quality/json_format/train_platinum.json",
    "../data/Annotations/Train/silver_quality/json_format/train_silver.json"
]

# Label-level frequencies: (subject_label, predicate, object_label) -> count
relation_vocab = defaultdict(int)

# Mention-level frequencies, used later for exact text matching:
# (subject_text, subject_label, predicate, object_text, object_label) -> count
relation_vocab_with_text = defaultdict(int)

for train_file in train_files:
    # A missing tier is reported and skipped; it is not treated as fatal.
    if not os.path.exists(train_file):
        print(f"Warning: {train_file} not found")
        continue

    with open(train_file, "r", encoding="utf-8") as f:
        train_data = json.load(f)

    for article in train_data.values():
        for relation in article['relations']:
            mention_key = (
                relation['subject_text_span'],
                relation['subject_label'],
                relation['predicate'],
                relation['object_text_span'],
                relation['object_label'],
            )
            _, s_label, predicate_name, _, o_label = mention_key

            # One tally per granularity: labels only, and labels + surface text.
            relation_vocab[(s_label, predicate_name, o_label)] += 1
            relation_vocab_with_text[mention_key] += 1

    print(f"Loaded {train_file}")

print(f"\nTotal unique relation triples (by label): {len(relation_vocab)}")
print(f"Total unique relation triples (with text): {len(relation_vocab_with_text)}")
print(f"Total relation instances: {sum(relation_vocab.values())}")

print(f"\nTop 10 most frequent relations (by text + label):")
# Stable descending sort by count, so ties keep their first-seen order.
sorted_relations = sorted(relation_vocab_with_text.items(), key=lambda item: item[1], reverse=True)
for mention_key, count in sorted_relations[:10]:
    subj_text, subj_label, pred, obj_text, obj_label = mention_key
    print(f"  {count:4d}x  '{subj_text}' [{subj_label}] --[{pred}]--> '{obj_text}' [{obj_label}]")

Loaded ../data/Annotations/Train/gold_quality/json_format/train_gold.json
Loaded ../data/Annotations/Train/platinum_quality/json_format/train_platinum.json
Loaded ../data/Annotations/Train/silver_quality/json_format/train_silver.json

Total unique relation triples (by label): 54
Total unique relation triples (with text): 9431
Total relation instances: 14065

Top 10 most frequent relations (by text + label):
   189x  'gut microbiota' [microbiome] --[is linked to]--> 'PD' [DDF]
   115x  'PD' [DDF] --[target]--> 'patients' [human]
    85x  'gut microbiota' [microbiome] --[is linked to]--> 'Parkinson's disease' [DDF]
    62x  'gut microbiota' [microbiome] --[located in]--> 'patients' [human]
    49x  'gut microbiota' [microbiome] --[is linked to]--> 'depression' [DDF]
    42x  'PM' [chemical] --[influence]--> 'PD' [DDF]
    36x  'MND' [DDF] --[target]--> 'patients' [human]
    35x  'gut microbiota' [microbiome] --[is linked to]--> 'neurodegenerative diseases' [DDF]
    34x  'PD' [DDF] --[t

## Load Dev Data

In [6]:
# Development-split annotations: a JSON object whose values are the documents.
dev_data_path = "../data/Annotations/Dev/json_format/dev.json"

with open(dev_data_path, "r", encoding="utf-8") as dev_file:
    dev_data = json.load(dev_file)

print(f"Loaded {len(dev_data)} documents from dev set")

# Sum the per-document entity counts across the whole dev set.
total_entities = 0
for article in dev_data.values():
    total_entities += len(article['entities'])
print(f"Total entities in dev set: {total_entities}")

Loaded 40 documents from dev set
Total entities in dev set: 1117

Total entities in dev set: 1117


## Predict Relations on Dev Set

In [13]:
from tqdm import tqdm

def find_relations_for_entities(entities, relation_vocab_with_text):
    """
    Predict relations for one document by exact lookup in the training vocabulary.

    For every ordered pair of distinct entities (subject, object) in the document
    whose (text_span, label) pairs occur together as subject and object of at least
    one vocabulary entry, a single relation is emitted. Its predicate is the one
    with the highest count for that subject/object pair; on a tie, the predicate
    that appears first in the vocabulary's iteration order wins.

    The vocabulary is scanned once per call (not once per subject entity), and only
    entries whose subject and object both occur in the document are retained.

    Args:
        entities: List of entity dictionaries with keys: start_idx, end_idx, location, text_span, label
        relation_vocab_with_text: Dict mapping (subject_text, subject_label, predicate, object_text, object_label) -> count

    Returns:
        List of predicted relation dictionaries (subject_*/predicate/object_* keys),
        ordered by subject entity as given in `entities`.
    """
    # Index entities by (text_span, label) so candidate objects are found in O(1).
    entities_by_text_label = defaultdict(list)
    for entity in entities:
        entities_by_text_label[(entity['text_span'], entity['label'])].append(entity)

    # Single pass over the vocabulary:
    # subject key -> {object key -> (best predicate, its count)}.
    # Dict insertion order keeps object keys in first-seen vocabulary order.
    best_by_subject = defaultdict(dict)
    for (subj_txt, subj_lbl, pred, obj_txt, obj_lbl), count in relation_vocab_with_text.items():
        subj_key = (subj_txt, subj_lbl)
        obj_key = (obj_txt, obj_lbl)
        # `in` on a defaultdict does not create the key, so the index stays clean.
        if subj_key not in entities_by_text_label or obj_key not in entities_by_text_label:
            continue
        current = best_by_subject[subj_key].get(obj_key)
        # Strict '>' keeps the first-seen predicate when counts tie.
        if current is None or count > current[1]:
            best_by_subject[subj_key][obj_key] = (pred, count)

    predicted_relations = []
    for subject_entity in entities:
        subj_key = (subject_entity['text_span'], subject_entity['label'])
        for obj_key, (best_predicate, _) in best_by_subject.get(subj_key, {}).items():
            # Pair the subject with every mention that matches the object key.
            for object_entity in entities_by_text_label[obj_key]:
                # Don't create self-relations (an entity paired with itself).
                if subject_entity == object_entity:
                    continue
                predicted_relations.append({
                    "subject_start_idx": subject_entity['start_idx'],
                    "subject_end_idx": subject_entity['end_idx'],
                    "subject_location": subject_entity['location'],
                    "subject_text_span": subject_entity['text_span'],
                    "subject_label": subject_entity['label'],
                    "predicate": best_predicate,
                    "object_start_idx": object_entity['start_idx'],
                    "object_end_idx": object_entity['end_idx'],
                    "object_location": object_entity['location'],
                    "object_text_span": object_entity['text_span'],
                    "object_label": object_entity['label']
                })

    return predicted_relations

# Run the vocabulary-lookup baseline over every dev document.
predictions = {}

for pmid, article in tqdm(dev_data.items(), desc="Processing dev data"):
    entities = article['entities']
    # Exact (text, label) matching against the mention-level training vocabulary.
    predicted_relations = find_relations_for_entities(entities, relation_vocab_with_text)
    # Keep the document's metadata and entities next to the predicted relations.
    predictions[pmid] = dict(
        metadata=article["metadata"],
        entities=entities,
        relations=predicted_relations,
    )


n_predicted_relations = sum(len(doc['relations']) for doc in predictions.values())
print(f"\nProcessed {len(predictions)} documents")
print(f"Total relations predicted: {n_predicted_relations}")

Processing dev data: 100%|██████████| 40/40 [00:00<00:00, 60.31it/s]


Processed 40 documents
Total relations predicted: 1365





## Save Predictions

In [14]:
# Destination of the raw (pre-evaluation-format) predictions.
output_path = "predictions/vanilla_re_predictions.json"

# Make sure the target directory is there before writing.
os.makedirs("predictions", exist_ok=True)

# Serialise first, then write the text out as UTF-8 (non-ASCII characters kept as-is).
serialized_predictions = json.dumps(predictions, ensure_ascii=False, indent=2)
with open(output_path, "w", encoding="utf-8") as out_file:
    out_file.write(serialized_predictions)

print(f"Predictions saved to {output_path}")

Predictions saved to predictions/vanilla_re_predictions.json


## Example Predictions

In [11]:
# Print a short side-by-side summary for a handful of dev documents.
print("Example Predictions:\n")

# The first five documents, in the order they appear in the dev file.
sample_pmids = list(dev_data)[:5]

for pmid in sample_pmids:
    article = dev_data[pmid]
    pred = predictions[pmid]

    print(f"Document PMID: {pmid}")
    print(f"Title: {article['metadata']['title'][:100]}...")
    print(f"\nNumber of entities: {len(article['entities'])}")
    print(f"Number of gold relations: {len(article['relations'])}")
    print(f"Number of predicted relations: {len(pred['relations'])}")

    # Only the first five predicted relations of each document are listed.
    print("\nSample predicted relations:")
    for relation in pred['relations'][:5]:
        subject_part = f"({relation['subject_text_span']} [{relation['subject_label']}])"
        object_part = f"({relation['object_text_span']} [{relation['object_label']}])"
        print(f"  {subject_part} --[{relation['predicate']}]--> {object_part}")

    print("-" * 80)
    print()

Example Predictions:

Document PMID: 36532064
Title: Hypothesis of a potential BrainBiota and its relation to CNS autoimmune inflammation....

Number of entities: 19
Number of gold relations: 12
Number of predicted relations: 29

Sample predicted relations:
  (neurological diseases [DDF]) --[change abundance]--> (gut microbiota [microbiome])
  (neurological diseases [DDF]) --[change abundance]--> (gut microbiota [microbiome])
  (neurological diseases [DDF]) --[target]--> (patients [human])
  (gut microbiota [microbiome]) --[is linked to]--> (neurological diseases [DDF])
  (gut microbiota [microbiome]) --[is linked to]--> (neurological diseases [DDF])
--------------------------------------------------------------------------------

Document PMID: 37212075
Title: IgA-Biome Profiles Correlate with Clinical Parkinson's Disease Subtypes....

Number of entities: 21
Number of gold relations: 11
Number of predicted relations: 15

Sample predicted relations:
  (Parkinson's disease [DDF]) --[cha

## Analysis: Relation Distribution by Predicate

In [12]:
from collections import Counter

# Predicate frequencies over the predicted relations.
pred_predicate_counts = Counter(
    relation['predicate']
    for doc in predictions.values()
    for relation in doc['relations']
)

# Predicate frequencies over the gold dev annotations.
gold_predicate_counts = Counter(
    relation['predicate']
    for doc in dev_data.values()
    for relation in doc['relations']
)

print("Relation Distribution by Predicate:")
print("="*60)
print(f"{'Predicate':<30} {'Gold':<10} {'Predicted':<10}")
print("-"*60)

# Every predicate seen on either side; a Counter reports 0 for the missing side.
all_predicates = set(gold_predicate_counts).union(pred_predicate_counts)
for predicate in sorted(all_predicates):
    print(f"{predicate:<30} {gold_predicate_counts[predicate]:<10} {pred_predicate_counts[predicate]:<10}")

print("-"*60)
print(f"{'TOTAL':<30} {sum(gold_predicate_counts.values()):<10} {sum(pred_predicate_counts.values()):<10}")

print("\n" + "="*60)
print("Relation Distribution by Subject-Predicate-Object Pattern:")
print("="*60)

# Label-level patterns (subject_label, predicate, object_label) among the predictions.
pred_pattern_counts = Counter(
    (relation['subject_label'], relation['predicate'], relation['object_label'])
    for doc in predictions.values()
    for relation in doc['relations']
)

print("\nTop 15 predicted patterns:")
for (subj, pred, obj), count in pred_pattern_counts.most_common(15):
    print(f"  {count:4d}x  ({subj}) --[{pred}]--> ({obj})")

Relation Distribution by Predicate:
Predicate                      Gold       Predicted 
------------------------------------------------------------
administered                   16         1         
affect                         89         162       
change abundance               11         149       
change effect                  32         8         
change expression              4          0         
compared to                    4          0         
impact                         54         95        
influence                      68         58        
interact                       9          6         
is a                           31         8         
is linked to                   74         374       
located in                     99         273       
part of                        13         56        
produced by                    6          38        
strike                         11         63        
target                         85         72        
us

## Processing to evaluation format

Remove relations not defined in the annotation guidelines and complete conversion to evaluation format

In [17]:
# Label-level triples (subject_label, predicate, object_label) that are accepted
# by remove_illegal_relations(); every other triple is discarded there.
LEGAL_RELATIONS = [
    ("DDF", "affect", "DDF"),
    ("microbiome", "is linked to", "DDF"),
    ("DDF", "target", "human"),
    ("drug", "change effect", "DDF"),
    ("DDF", "is a", "DDF"),
    ("microbiome", "located in", "human"),
    ("chemical", "influence", "DDF"),
    ("dietary supplement", "influence", "DDF"),
    ("DDF", "target", "animal"),
    ("chemical", "impact", "microbiome"),
    ("anatomical location", "located in", "animal"),
    ("microbiome", "located in", "animal"),
    ("chemical", "located in", "anatomical location"),
    ("bacteria", "part of", "microbiome"),
    ("DDF", "strike", "anatomical location"),
    ("drug", "administered", "animal"),
    ("bacteria", "influence", "DDF"),
    ("drug", "impact", "microbiome"),
    ("DDF", "change abundance", "microbiome"),
    ("microbiome", "located in", "anatomical location"),
    ("microbiome", "used by", "biomedical technique"),
    ("chemical", "produced by", "microbiome"),
    ("dietary supplement", "impact", "microbiome"),
    ("bacteria", "located in", "animal"),
    ("animal", "used by", "biomedical technique"),
    ("chemical", "impact", "bacteria"),
    ("chemical", "located in", "animal"),
    ("food", "impact", "bacteria"),
    ("microbiome", "compared to", "microbiome"),
    ("human", "used by", "biomedical technique"),
    ("bacteria", "change expression", "gene"),
    ("chemical", "located in", "human"),
    ("drug", "interact", "chemical"),
    ("food", "administered", "human"),
    ("DDF", "change abundance", "bacteria"),
    ("chemical", "interact", "chemical"),
    ("chemical", "part of", "chemical"),
    ("dietary supplement", "impact", "bacteria"),
    ("DDF", "interact", "chemical"),
    ("food", "impact", "microbiome"),
    ("food", "influence", "DDF"),
    ("bacteria", "located in", "human"),
    ("dietary supplement", "administered", "human"),
    ("bacteria", "interact", "chemical"),
    ("drug", "change expression", "gene"),
    ("drug", "impact", "bacteria"),
    ("drug", "administered", "human"),
    ("anatomical location", "located in", "human"),
    ("dietary supplement", "change expression", "gene"),
    ("chemical", "change expression", "gene"),
    ("bacteria", "interact", "bacteria"),
    ("drug", "interact", "drug"),
    ("microbiome", "change expression", "gene"),
    ("bacteria", "interact", "drug"),
    ("food", "change expression", "gene")
]

def remove_illegal_relations(data):
    """
    Build a filtered copy of `data` that keeps only relations whose
    (subject_label, predicate, object_label) triple is listed in LEGAL_RELATIONS.

    The input is not modified. Each output article contains exactly three keys:
    'metadata' (the same object as in the input), 'entities' (new dicts restricted
    to start_idx, end_idx, location, text_span, label) and 'relations' (new dicts
    restricted to the eleven subject_*/predicate/object_* fields). Any other keys
    on articles, entities or relations are dropped.

    A summary of total / kept / discarded relations, plus the distinct discarded
    triples, is printed.

    Args:
        data: Dict mapping pmid -> article dict with 'metadata', 'entities' and
            'relations' keys.

    Returns:
        Dict mapping pmid -> filtered article dict.
    """
    # Set membership is O(1) per relation; rebuilt on each call so later edits
    # to LEGAL_RELATIONS are always honoured.
    legal_relations = set(LEGAL_RELATIONS)

    entity_fields = ("start_idx", "end_idx", "location", "text_span", "label")
    relation_fields = (
        "subject_start_idx", "subject_end_idx", "subject_location",
        "subject_text_span", "subject_label",
        "predicate",
        "object_start_idx", "object_end_idx", "object_location",
        "object_text_span", "object_label",
    )

    dump_dict = {}
    total_rels = 0
    kept_rels = 0
    discared_rels = 0
    discared_rels_set = set()

    for pmid, article in data.items():
        kept_entities = [{field: entity[field] for field in entity_fields}
                         for entity in article['entities']]

        kept_relations = []
        for relation in article['relations']:
            total_rels += 1
            rel_key = (relation["subject_label"], relation["predicate"], relation["object_label"])
            if rel_key in legal_relations:
                kept_rels += 1
                kept_relations.append({field: relation[field] for field in relation_fields})
            else:
                discared_rels += 1
                discared_rels_set.add(rel_key)

        dump_dict[pmid] = {
            'metadata': article['metadata'],
            'entities': kept_entities,
            'relations': kept_relations,
        }

    print(f'total_rels: {total_rels}')
    print(f'kept_rels: {kept_rels}')
    print(f'discared_rels: {discared_rels}')
    print()
    print(f'discared_rels_set: {discared_rels_set}')
    for entry in discared_rels_set:
        print(entry)

    return dump_dict


In [18]:
# The filtered copy used to be stored only in `dump_dict`, while every later cell
# (sorting, tag-based relation generation, final export) kept working on the
# unfiltered `predictions`, so the filtering never reached the saved file.
# Rebind `predictions` so the remaining pipeline operates on the legal relations.
dump_dict = remove_illegal_relations(predictions)
predictions = dump_dict

total_rels: 1365
kept_rels: 1365
discared_rels: 0

discared_rels_set: set()


Sort entities and relations

In [19]:
def sort_entities(release_dict):
    """
    Sort each article's "entities" list in place.

    Entities whose location is "title" come first, all other locations after;
    within each group entities are ordered by ascending "start_idx".
    Nothing is returned.
    """
    for article in release_dict.values():
        # False sorts before True, so title entities (not-title == False) lead.
        article["entities"].sort(
            key=lambda entity: (entity["location"] != "title", entity["start_idx"])
        )

In [20]:
# Order every document's entities in place: title mentions first, then by start offset.
sort_entities(predictions)

In [21]:
def sort_relations(release_dict):
    """
    Sort each article's "relations" list in place.

    Relations whose subject is located in the "title" come first, all others
    after; within each group they are ordered by ascending "subject_start_idx".
    The sort is stable, so relations sharing the same key keep their relative
    order. Nothing is returned.
    """
    for article in release_dict.values():
        # False sorts before True, so title-subject relations lead.
        article["relations"].sort(
            key=lambda rel: (rel["subject_location"] != "title", rel["subject_start_idx"])
        )

In [22]:
# Order every document's relations in place: title subjects first, then by subject start offset.
sort_relations(predictions)

Generate Binary Tag Based Relations

In [23]:
def add_binary_tag_based_relations_to_release_dict(release_dict):
    """
    Add a "binary_tag_based_relations" list to every article, in place.

    The list holds one {"subject_label", "object_label"} entry per distinct label
    pair found in the article's "relations". Entries already present under that
    key are kept, and only pairs not yet listed are appended, so calling the
    function repeatedly does not create duplicates. New entries follow the order
    in which the pairs first appear in "relations", which makes the output
    deterministic across runs.

    Args:
        release_dict: Dict mapping pmid -> article dict with a "relations" list.
    """
    for article in release_dict.values():
        # dict.fromkeys de-duplicates while preserving first-seen order.
        pairs = dict.fromkeys(
            (relation["subject_label"], relation["object_label"])
            for relation in article["relations"]
        )
        entries = article.setdefault("binary_tag_based_relations", [])
        already_listed = {(e.get("subject_label"), e.get("object_label")) for e in entries}
        entries.extend(
            {"subject_label": subject_label, "object_label": object_label}
            for subject_label, object_label in pairs
            if (subject_label, object_label) not in already_listed
        )

In [24]:
# Attach, per document, the distinct (subject_label, object_label) pairs of its relations.
add_binary_tag_based_relations_to_release_dict(predictions)

Generate Ternary Tag Based Relations

In [25]:
def add_ternary_tag_based_relations_to_release_dict(release_dict):
    """
    Add a "ternary_tag_based_relations" list to every article, in place.

    The list holds one {"subject_label", "predicate", "object_label"} entry per
    distinct label triple found in the article's "relations". Entries already
    present under that key are kept, and only triples not yet listed are appended,
    so calling the function repeatedly does not create duplicates. New entries
    follow the order in which the triples first appear in "relations", which makes
    the output deterministic across runs.

    Args:
        release_dict: Dict mapping pmid -> article dict with a "relations" list.
    """
    for article in release_dict.values():
        # dict.fromkeys de-duplicates while preserving first-seen order.
        triplets = dict.fromkeys(
            (relation["subject_label"], relation["predicate"], relation["object_label"])
            for relation in article["relations"]
        )
        entries = article.setdefault("ternary_tag_based_relations", [])
        already_listed = {
            (e.get("subject_label"), e.get("predicate"), e.get("object_label"))
            for e in entries
        }
        entries.extend(
            {"subject_label": triplet[0], "predicate": triplet[1], "object_label": triplet[2]}
            for triplet in triplets
            if triplet not in already_listed
        )

In [26]:

# Attach, per document, the distinct (subject_label, predicate, object_label) triples of its relations.
add_ternary_tag_based_relations_to_release_dict(predictions)

Generate Ternary Mention Based Relations

In [28]:
def add_ternary_mention_based_relations_to_release_dict(release_dict):
    """
    Add a "ternary_mention_based_relations" list to every article, in place.

    The list holds one entry per distinct
    (subject_text_span, subject_label, predicate, object_text_span, object_label)
    tuple found in the article's "relations". Entries already present under that
    key are kept, and only tuples not yet listed are appended, so calling the
    function repeatedly does not create duplicates. New entries follow the order
    in which the tuples first appear in "relations", which makes the output
    deterministic across runs.

    Args:
        release_dict: Dict mapping pmid -> article dict with a "relations" list.
    """
    fields = ("subject_text_span", "subject_label", "predicate", "object_text_span", "object_label")

    for article in release_dict.values():
        # dict.fromkeys de-duplicates while preserving first-seen order.
        mentions = dict.fromkeys(
            tuple(relation[field] for field in fields)
            for relation in article["relations"]
        )
        entries = article.setdefault("ternary_mention_based_relations", [])
        already_listed = {tuple(e.get(field) for field in fields) for e in entries}
        entries.extend(
            dict(zip(fields, mention))
            for mention in mentions
            if mention not in already_listed
        )

In [29]:
# Attach, per document, the distinct (subject text, subject label, predicate, object text, object label) tuples.
add_ternary_mention_based_relations_to_release_dict(predictions)

In [30]:
# Derive the evaluation-format path only once: without the guard, re-running this
# cell would stack the suffix ("..._eval_format_eval_format.json").
eval_suffix = "_eval_format.json"
if not output_path.endswith(eval_suffix):
    output_path = output_path.replace(".json", eval_suffix)

# ensure_ascii=False keeps non-ASCII text readable in the UTF-8 file, consistent
# with how the raw predictions were dumped earlier in this notebook.
with open(output_path, 'w', encoding='utf-8') as file:
    json.dump(predictions, file, ensure_ascii=False, indent=2)