In [1]:
import re
import numpy as np
from datasets import load_dataset
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoTokenizer, AutoModel
import torch
import nltk
from nltk.corpus import stopwords
import string
from gensim.models import Word2Vec


stop_words = set(stopwords.words('english'))
punctuation = set(string.punctuation)

ner_dataset = load_dataset("tner/bc5cdr")


test_sentences = ner_dataset['test']['tokens']
test_tags = ner_dataset['test']['tags']

tag_to_entity = {0: 'O', 1: 'B-Chemical', 2: 'B-Disease', 3: 'I-Chemical', 4: 'I-Disease'}


def extract_entities_debug(sentences, tags, entity_type="Disease"):
    entities = []
    for sentence, tag_seq in zip(sentences, tags):
        entity = []
        for token, tag in zip(sentence, tag_seq):
            tag_label = tag_to_entity.get(tag, 'O')
            if tag_label == f"B-{entity_type}":
                if entity:
                    entities.append(" ".join(entity))
                    entity = []
                entity.append(token)
            elif tag_label == f"I-{entity_type}":
                entity.append(token)
            else:
                if entity:
                    entities.append(" ".join(entity))
                    entity = []
        if entity:
            entities.append(" ".join(entity))
    print("Extracted Entities before deduplication:", entities)
    return list(set(entities)) 


disease_entities = extract_entities_debug(test_sentences, test_tags)
print("Extracted Disease Entities:", disease_entities)


if not disease_entities:
    print("No disease entities found. Please check the tags in the dataset.")
else:
    
    def preprocess_entities(entities):
        filtered_entities = []
        for entity in entities:
            entity = entity.lower() 
            entity = re.sub(r'[^a-z\s]', '', entity)  
            words = entity.split()
            filtered_words = [word for word in words if word not in stop_words and word not in punctuation]
            if filtered_words:
                filtered_entities.append(filtered_words)
        return filtered_entities

    
    disease_entities_processed = preprocess_entities(disease_entities)
    print("Preprocessed Disease Entities:", disease_entities_processed)

    
    if not disease_entities_processed:
        raise ValueError("No valid disease entities found after preprocessing.")

    
    disease_entities_joined = [" ".join(entity) for entity in disease_entities_processed]

    
    
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    model = AutoModel.from_pretrained("bert-base-uncased")

    
    def get_bert_embeddings(text_list):
        inputs = tokenizer(text_list, return_tensors='pt', padding=True, truncation=True)
        with torch.no_grad():
            outputs = model(**inputs)
        return outputs.last_hidden_state.mean(dim=1)

    
    embeddings_bert = get_bert_embeddings(disease_entities_joined)

    
    word2vec_model = Word2Vec(sentences=disease_entities_processed, vector_size=100, window=5, min_count=1, workers=4)

    
    def get_word2vec_embeddings(entities, model):
        embeddings = []
        for entity in entities:
            embedding = np.mean([model.wv[word] for word in entity if word in model.wv], axis=0)
            embeddings.append(embedding)
        return np.array(embeddings)

    
    embeddings_w2v = get_word2vec_embeddings(disease_entities_processed, word2vec_model)

    
    query_disease = "dyskinesia"  
    query_disease_processed = [word for word in query_disease.lower().split() if word not in stop_words and word not in punctuation]

    if not query_disease_processed:
        raise ValueError(f"The query disease '{query_disease}' is not valid after preprocessing.")

    
    query_embedding_bert = get_bert_embeddings([" ".join(query_disease_processed)])
    query_embedding_w2v = np.mean([word2vec_model.wv[word] for word in query_disease_processed if word in word2vec_model.wv], axis=0)

    
    cosine_similarities_bert = cosine_similarity(query_embedding_bert, embeddings_bert).flatten()

    
    similar_indices_bert = cosine_similarities_bert.argsort()[-6:-1][::-1]
    dissimilar_indices_bert = cosine_similarities_bert.argsort()[:5]

    similar_diseases_bert = [(disease_entities[i], cosine_similarities_bert[i]) for i in similar_indices_bert]
    dissimilar_diseases_bert = [(disease_entities[i], cosine_similarities_bert[i]) for i in dissimilar_indices_bert]

    print("BERT Similar Diseases:", similar_diseases_bert)
    print("BERT Dissimilar Diseases:", dissimilar_diseases_bert)

    
    cosine_similarities_w2v = cosine_similarity([query_embedding_w2v], embeddings_w2v).flatten()

    
    similar_indices_w2v = cosine_similarities_w2v.argsort()[-6:-1][::-1]
    dissimilar_indices_w2v = cosine_similarities_w2v.argsort()[:5]

    similar_diseases_w2v = [(disease_entities[i], cosine_similarities_w2v[i]) for i in similar_indices_w2v]
    dissimilar_diseases_w2v = [(disease_entities[i], cosine_similarities_w2v[i]) for i in dissimilar_indices_w2v]

    print("Word2Vec Similar Diseases:", similar_diseases_w2v)
    print("Word2Vec Dissimilar Diseases:", dissimilar_diseases_w2v)

Extracted Entities before deduplication: ['delirium', 'ulcers', 'delirium', 'delirium', 'hypotension', 'scleroderma', 'Scleroderma', 'SRC', 'systemic', 'SSc', 'SRC', 'thrombotic', 'SSc', 'SRC', 'SSc', 'psychosis', 'psychosis', 'psychosis', 'psychosis', 'psychiatric', 'psychotic', 'psychotic', 'depressive', 'bipolar', 'antisocial', 'psychosis', 'Major', 'antisocial', 'psychosis', 'psychosis', 'affective', 'antisocial', 'psychotic', "Parkinson's", 'dyskinetic', "Parkinson's", 'PD', 'dyskinesias', 'LIDs', 'LIDs', 'LIDs', 'LIDs', 'abnormal', 'cystitis', 'cystitis', '82334', 'pain', 'pain', 'edema', 'cystitis', 'hepatitis', 'hepatotoxicity', 'hepatitis', 'hepatitis', 'hepatotoxicity', 'hepatic', 'multiple', 'multiple', 'MM', 'MM', 'peripheral', 'MM', 'A', 'anxiety', 'A', 'anxiety', 'anxiety', 'anxiety', 'anxiety', 'cardiotoxicity', 'Cardiovascular', 'CVDs', 'cardiotoxicity', 'acid', 'oxide', 'multiple', 'Myeloma', 'multiple', 'RRMM', 'RRMM', 'myelosuppression', 'Peripheral', 'deep', 'RRMM',

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

BERT Similar Diseases: [('peptic', 0.7743856), ('accelerated', 0.7015696), ('2, 3, 4-tetrahydroanaphthalene', 0.699726), ('Hepatitis', 0.6866766), ('arabinoside', 0.677312)]
BERT Dissimilar Diseases: [('retinoblastoma', 0.22445168), ('menorrhagia', 0.24199176), ('hyperkalemia', 0.24508162), ('isotretinoin', 0.24509302), ('Depression', 0.24509302)]
Word2Vec Similar Diseases: [('peptic', 1.0), ('FQ', 0.32622227), ('Drug', 0.28428304), ('asthmatics', 0.2800613), ('subependymal', 0.27060902)]
Word2Vec Dissimilar Diseases: [('subcellular', -0.32582596), ('hyperthyroidism', -0.31490138), ('atherosclerotic', -0.27141875), ('blurred', -0.25252727), ('menopausal', -0.23699947)]
