In [1]:
from datasets import load_from_disk
import pandas as pd
from collections import Counter
import re

In [2]:
!cp -r /kaggle/input/clinical/data /kaggle/working/data

In [3]:
ds = load_from_disk("/kaggle/working/data")
print(ds)
dev_size = 500_000
ds_dev = ds.shuffle(seed=42).select(range(min(dev_size, len(ds))))

Dataset({
    features: ['nct_id', 'brief_title_clean', 'brief_summary_clean', 'detailed_description_clean', 'eligibility_criteria_clean', 'keywords_clean', 'mesh_terms_clean', 'condition_browse_module_clean', 'intervention_browse_module_clean', 'conditions', 'interventions', 'combined_text', 'text_len'],
    num_rows: 479038
})


In [4]:
STOPWORDS = set([
    "and","or","of","in","all","by","with","without","a","the","an","to","for","type","site",
    "before","after","non","at","vs","i","ii","iii","iv","1","2","3","4","5",
    "patient", "patients", "study", "studies", "disease", "diseases", 
    "treatment", "treatments", "group", "groups", "criteria", "subject", 
    "subjects", "trial", "trials", "clinical", "randomized", "placebo",
    "intervention", "efficacy", "safety", "evaluate", "method", "results" # common words that ruin the statistics
])

def valid_token(tok):
    tok = tok.lower().strip(",.()")
    if tok in STOPWORDS:
        return False
    if len(tok) <= 2 and not tok.isalpha():
        return False
    if tok.isdigit():
        return False
    return True

def get_tokens_from_field(dataset, field):
    for x in dataset[field]:
        if not x:
            continue
        for tok in str(x).split():
            yield tok.strip().lower()

anchor_fields = [
    "mesh_terms_clean",
    "condition_browse_module_clean",
    "intervention_browse_module_clean",
]

cnt = Counter()
for field in anchor_fields:
    for tok in get_tokens_from_field(ds_dev, field):
        tok = tok.lower().strip(",.()")
        if valid_token(tok):
            cnt[tok] += 1

In [5]:
top_tokens = [tok for tok, _ in cnt.most_common(2000)]

In [6]:
from gensim.models import Word2Vec
from sklearn.cluster import KMeans
import numpy as np

class ClinicalSentences:
    def __init__(self, dataset, fields):
        self.dataset = dataset
        self.fields = fields
        
    def __iter__(self):
        for row in self.dataset:
            for field in self.fields:
                text = row.get(field)
                if text:
                    yield [t.strip().lower() for t in str(text).split()]


sentences = ClinicalSentences(ds_dev, anchor_fields)
model = Word2Vec(sentences, vector_size=100, window=5, min_count=10, workers=4)


token_vectors = []
valid_tokens = []

for tok in top_tokens:
    if tok in model.wv:
        token_vectors.append(model.wv[tok])
        valid_tokens.append(tok)


num_clusters = 20
kmeans = KMeans(n_clusters=num_clusters, random_state=42)
labels = kmeans.fit_predict(token_vectors)

generated_anchors = {}
for tok, label in zip(valid_tokens, labels):
    if label not in generated_anchors:
        generated_anchors[label] = []
    generated_anchors[label].append(tok)

print(f"\nGenerated {num_clusters} Anchor Groups:")
print("-" * 50)
for cid in sorted(generated_anchors.keys()):
    terms = generated_anchors[cid]
    print(f"Group {cid} ({len(terms)} terms): {', '.join(terms[:15])}...")

NEW_ANCHOR_GROUPS = {f"group_{i}": terms for i, terms in generated_anchors.items()}




Generated 20 Anchor Groups:
--------------------------------------------------
Group 0 (13 terms): antagonists, modulators, agonists, hypnotics, sedatives, alpha-2, alpha-agonists, beta-agonists, beta-antagonists, alpha-1, beta-2, alpha-antagonists, beta-1...
Group 1 (217 terms): insulin, other, substitutes, hematinics, b, antibiotics, vaccines, zinc, lipid, c, complex, globin, d, folate, iron...
Group 2 (421 terms): system, infant, traumatic, newborn, knee, low, rheumatoid, nevus, major, influenza, fever, psychological, tuberculosis, post-traumatic, defects...
Group 3 (82 terms): gastrointestinal, liver, intestinal, ulcer, anatomical, bowel, resistance, abdominal, cirrhosis, allergic, fistula, vomiting, thromboembolism, crohn, polyps...
Group 4 (151 terms): immunologic, antibodies, immunoglobulins, paclitaxel, immunosuppressive, vitamins, folic, tyrosine, cyclophosphamide, mitosis, doxorubicin, monoclonal, phytogenic, carboplatin, fludarabine...
Group 5 (91 terms): agents, drugs, che

In [7]:
import json

with open("anchor_groups.json", "w") as f:
    json.dump(NEW_ANCHOR_GROUPS, f, indent=2)

In [8]:
from scipy.spatial.distance import cosine

MIN_TERMS = 5
cleaned_anchors = {cid: terms for cid, terms in generated_anchors.items() if len(terms) >= MIN_TERMS}

MAX_RATIO = 0.20
total_tokens = len(valid_tokens)
cleaned_anchors = {cid: terms for cid, terms in cleaned_anchors.items() if len(terms) < (total_tokens * MAX_RATIO)}


centroids = {}
for cid, terms in cleaned_anchors.items():
    vectors = [model.wv[t] for t in terms if t in model.wv]
    centroids[cid] = np.mean(vectors, axis=0)

merged_anchors = cleaned_anchors.copy()

print(f"Original groups: {num_clusters} | Cleaned groups: {len(cleaned_anchors)}")

Original groups: 20 | Cleaned groups: 17


In [9]:
import json


cleaned_anchors_serializable = {str(k): v for k, v in cleaned_anchors.items()}

with open("anchor_groups_cleaned.json", "w") as f:
    json.dump(cleaned_anchors_serializable, f, indent=2)
