In [10]:
from joblib import dump, load
# Trained classifier used for every prediction in this notebook.
# NOTE: joblib.load unpickles the file, which can execute arbitrary code --
# only load model artifacts that come from a trusted source.
# The commented-out lines are alternative checkpoints; enable exactly one.
# clf = load('../CRED_application/Best_CRED_trained_model_new_data_svm.joblib')
# clf = load('CDR_trained_model_xgb.joblib')
clf = load('../CRED_application/CRED_trained_model_new_data_xgb.joblib')

# clf = load('CDR_trained_model.joblib')


from transformers import AutoTokenizer, BertModel
import torch

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Tokenizer and encoder are loaded from the same BioBERT checkpoint.
# `biobert` and `model` are two names for the same module object.
tokenizer = AutoTokenizer.from_pretrained("dmis-lab/biobert-base-cased-v1.1")
biobert = model = BertModel.from_pretrained("dmis-lab/biobert-base-cased-v1.1")
model.to(device)

In [14]:
import nltk
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize

def remove_stopwords(text):
    """Return ``text`` with English stop words removed.

    The text is split with NLTK's ``word_tokenize``; every token whose
    lower-cased form is in NLTK's English stop-word list is discarded and the
    remaining tokens are joined back together with single spaces.
    """
    # The stop-word corpus may need a one-time download:
    # nltk.download('stopwords')
    english_stopwords = frozenset(stopwords.words('english'))

    kept = []
    for token in word_tokenize(text):
        if token.lower() in english_stopwords:
            continue
        kept.append(token)
    return ' '.join(kept)

In [15]:
def get_specific_token_embeddings(sentence, context_range=2):
    """Build the classifier feature vector for one marked-up sentence.

    The sentence is run through the notebook-level ``model`` (using the
    notebook-level ``tokenizer`` and ``device``). For every token that matches
    a word piece of ``"@GeneSrc$"`` (and, separately, of ``"@DiseaseTgt$"``),
    a window of ``context_range`` positions on each side is sliced out of the
    last hidden state. The windows of each marker are concatenated and
    averaged, and the two averages are concatenated.

    Args:
        sentence: Text containing both the ``@GeneSrc$`` and ``@DiseaseTgt$``
            markers.
        context_range: Number of positions taken on each side of a matching
            token. Defaults to 2.

    Returns:
        A NumPy array of shape ``(1, 2 * hidden_size)``: the averaged gene
        context followed by the averaged disease context.

    Raises:
        ValueError: If no token of the sentence matches one of the markers.

    NOTE(review): token indices come from ``tokenizer.tokenize`` while the
    embedding rows come from ``tokenizer(...)``. If the latter prepends a
    special token such as [CLS], every window is shifted by one position
    relative to the intended token -- confirm. The indexing is left as is
    because the loaded classifier was presumably trained on features
    extracted this way.
    NOTE(review): a token matches a marker when it equals ANY word piece of
    that marker, so pieces shared by both markers (presumably "@" and "$")
    contribute to both windows -- confirm this is intended.
    """
    # 1. Encode the sentence for the model (truncated to at most 512 positions).
    inputs = tokenizer(sentence, return_tensors='pt', padding=True, truncation=True, max_length=512).to(device)

    # 2. Find the indices of "@GeneSrc$" and "@DiseaseTgt$"
    tokenized_sentence = tokenizer.tokenize(sentence)
    gene_src_token = tokenizer.tokenize("@GeneSrc$")
    disease_tgt_token = tokenizer.tokenize("@DiseaseTgt$")

    gene_src_indices = [i for i, token in enumerate(tokenized_sentence) if token in gene_src_token]
    disease_tgt_indices = [i for i, token in enumerate(tokenized_sentence) if token in disease_tgt_token]

    # Fail with a readable message instead of an opaque torch.cat error on an
    # empty list further down.
    if not gene_src_indices or not disease_tgt_indices:
        raise ValueError(
            f"Sentence is missing the @GeneSrc$ and/or @DiseaseTgt$ marker: {sentence!r}"
        )

    # Run the sentence through BioBERT
    with torch.no_grad():
        outputs = model(**inputs)
    embeddings = outputs['last_hidden_state'][0]  # hidden states of the single sentence in the batch
    hidden_size = embeddings.shape[-1]  # taken from the tensor rather than hard-coded

    # 3. Retrieve the embeddings for the surrounding tokens
    def get_context_embeddings(indices):
        windows = []
        for idx in indices:
            start = max(0, idx - context_range)
            end = min(idx + context_range + 1, len(tokenized_sentence))
            windows.append(embeddings[start:end])
        return torch.cat(windows).view(-1, hidden_size)

    gene_src_embeddings = get_context_embeddings(gene_src_indices)
    disease_tgt_embeddings = get_context_embeddings(disease_tgt_indices)

    # 4. Compute the average of the embeddings
    avg_gene_src_embedding = torch.mean(gene_src_embeddings, dim=0)
    avg_disease_tgt_embedding = torch.mean(disease_tgt_embeddings, dim=0)

    combined_embedding = torch.cat([avg_gene_src_embedding, avg_disease_tgt_embedding], dim=0)
    combined_embedding_np = combined_embedding.cpu().numpy().reshape(1, -1)  # tensor -> 2-D NumPy array

    # Averaging an empty window (e.g. a marker beyond the truncated input)
    # yields NaN; print the offending sentence. Checked on the tensor so this
    # function does not depend on `np` being imported by a later cell.
    if torch.isnan(combined_embedding).any():
        print(sentence)

    return combined_embedding_np


In [16]:
import pandas as pd
import numpy as np
from transformers import AutoTokenizer, AutoModel
from sklearn.base import BaseEstimator, ClassifierMixin
from tqdm import tqdm

#Code for perturbation and generating importance scores

def perturb_sentence(sentence, tokenizer):
    """Return one variant of ``sentence`` per token, with that token masked.

    The sentence is split with ``tokenizer.tokenize``; variant ``k`` has token
    ``k`` replaced by the literal ``'[MASK]'`` and is turned back into text
    with ``tokenizer.convert_tokens_to_string``. The result has as many
    entries as the sentence has tokens, in token order.
    """
    pieces = tokenizer.tokenize(sentence)
    return [
        tokenizer.convert_tokens_to_string([*pieces[:position], '[MASK]', *pieces[position + 1:]])
        for position in range(len(pieces))
    ]

def compute_influences(sentence, clf, tokenizer, biobert):
    """Score how much masking each token changes the classifier's confidence.

    Args:
        sentence: The sentence to explain.
        clf: Classifier exposing ``predict_proba`` and ``predict`` on the
            array returned by ``get_specific_token_embeddings``.
        tokenizer: Tokenizer handed to ``perturb_sentence`` to build the
            masked variants.
        biobert: Unused here; kept so existing calls keep working. The
            embeddings are produced by ``get_specific_token_embeddings``,
            which does not take a model argument.

    Returns:
        ``(influences, org_probs, perturbed_probs, original_prediction)``.
        The three lists have one entry per masked variant; ``org_probs``
        repeats the confidence of the unmasked sentence so that it lines up
        with the other two. ``original_prediction`` is ``clf.predict`` on the
        unmasked sentence.
    """
    perturbed_sentences = perturb_sentence(sentence, tokenizer)

    original_embedding = get_specific_token_embeddings(sentence).reshape(1, -1)
    # Column 1 of predict_proba -- presumably the positive (causal) class.
    original_confidence = clf.predict_proba(original_embedding)[0][1]
    original_prediction = clf.predict(original_embedding)[0]

    influences = []
    perturbed_probs = []
    for perturbed in perturbed_sentences:
        perturbed_embedding = get_specific_token_embeddings(perturbed).reshape(1, -1)
        perturbed_confidence = clf.predict_proba(perturbed_embedding)[0][1]

        # Confidence drop caused by the mask, weighted by how far the larger of
        # the two confidences lies above 0.5 (weight 0 when neither exceeds 0.5).
        influence = (max(perturbed_confidence - 0.5, original_confidence - 0.5, 0)) * (original_confidence - perturbed_confidence)
        influences.append(influence)
        perturbed_probs.append(perturbed_confidence)

    org_probs = [original_confidence] * len(perturbed_probs)
    return influences, org_probs, perturbed_probs, original_prediction

def rank_words_by_influence(sentence, influences, tokenizer, org_probs, perturbed_probs):
    """Merge per-token scores into per-word scores and rank words by influence.

    ``influences``, ``org_probs`` and ``perturbed_probs`` are expected to be
    aligned one-to-one with ``tokenizer.tokenize(sentence)``. Continuation
    pieces (tokens starting with ``##``) are glued onto the token before them,
    and each score of the resulting word is the mean over its pieces. Words
    that are a single character after merging (punctuation, stray letters)
    are dropped.

    Args:
        sentence: The sentence the scores were computed for.
        influences: Influence score per token.
        tokenizer: Tokenizer used to split ``sentence`` into tokens.
        org_probs: Unperturbed confidence per token.
        perturbed_probs: Confidence with the respective token masked.

    Returns:
        ``(ranked_words, all_words, all_influence_scores,
        all_org_prob_scores, all_perturbed_prob_scores)``. The last four lists
        are in sentence order and index-aligned; ``ranked_words`` holds the
        same words sorted by descending influence (ties broken by descending
        word).
    """
    tokens = tokenizer.tokenize(sentence)

    # One entry per whole word: [text, influences, org_probs, perturbed_probs].
    groups = []
    for influence, org_prob, perturbed_prob, token in zip(influences, org_probs, perturbed_probs, tokens):
        is_continuation = token.startswith("##")
        if is_continuation and groups:
            group = groups[-1]
            group[0] += token[2:]  # strip only the "##" prefix
        else:
            # A new word starts here. A leading continuation piece with nothing
            # to attach to is kept as a word of its own.
            group = [token[2:] if is_continuation else token, [], [], []]
            groups.append(group)
        group[1].append(influence)
        group[2].append(org_prob)
        group[3].append(perturbed_prob)

    all_words = []
    all_influence_scores = []
    all_org_prob_scores = []
    all_perturbed_prob_scores = []
    for text, word_influences, word_org_probs, word_perturbed_probs in groups:
        if len(text) <= 1:
            continue  # punctuation / single characters carry no word-level meaning
        piece_count = len(word_influences)
        all_words.append(text)
        all_influence_scores.append(sum(word_influences) / piece_count)
        all_org_prob_scores.append(sum(word_org_probs) / piece_count)
        all_perturbed_prob_scores.append(sum(word_perturbed_probs) / piece_count)

    ranked_words = [word for _, word in sorted(zip(all_influence_scores, all_words), reverse=True)]
    return ranked_words, all_words, all_influence_scores, all_org_prob_scores, all_perturbed_prob_scores


def aggregate_word_importance(words, imp_scores, org_prob_scores, perturbed_prob_scores):
    """Collapse repeated words into one ``[importance, org_prob, perturbed_prob]`` entry.

    Words of two characters or fewer are ignored. When a word occurs more than
    once, each of the three scores is reduced independently with ``max``, so
    the stored triple may combine values from different occurrences.

    Returns:
        dict mapping each kept word to its list of three scores, in order of
        first appearance.
    """
    merged = {}
    for token, *triple in zip(words, imp_scores, org_prob_scores, perturbed_prob_scores):
        if len(token) <= 2:
            continue
        previous = merged.get(token)
        if previous is None:
            # First time this word is seen.
            merged[token] = triple
        else:
            # Keep the element-wise maximum over all occurrences so far.
            merged[token] = [max(old, new) for old, new in zip(previous, triple)]
    return merged


def save_word_importance_to_file(word_importance, row, original_prediction, filename='word_importance_score_cred_withpred_traindata_xgb.tsv'):
    """Append one TSV line per word of ``word_importance`` to ``filename``.

    Args:
        word_importance: Mapping ``word -> [importance, org_prob, perturbed_prob]``.
        row: Mapping providing the ``'index'``, ``'id1'`` and ``'id2'`` values
            written in the first three columns.
        original_prediction: Value written in the last column of every line.
        filename: Target file. It is opened in append mode, so repeated calls
            (and repeated notebook runs) keep adding lines; the header is only
            written when the file is empty.
    """
    # UTF-8 is set explicitly so words with non-ASCII characters are written
    # the same way on every platform instead of depending on the locale default.
    with open(filename, 'a', encoding='utf-8') as file:
        if file.tell() == 0:  # empty file -> start with the header
            file.write("PMID\tgeneid\tdiseaseid\tword\timportance_score\torg_prob_score\tperturbed_prob_score\toriginal_prediction\n")
        for word, scores in word_importance.items():
            # index, id1, id2, word, importance score, org_prob_score, perturbed_prob_score, original prediction
            file.write(f"{row['index']}\t{row['id1']}\t{row['id2']}\t{word}\t{scores[0]}\t{scores[1]}\t{scores[2]}\t{original_prediction}\n")


def main(row):
    """Compute, aggregate and persist word-importance scores for one data row.

    Reads ``row['sentence']``, scores every token with ``compute_influences``
    (using the notebook-level ``clf``, ``tokenizer`` and ``biobert``), merges
    the scores per word, appends the aggregated result to the output file via
    ``save_word_importance_to_file`` and returns it.

    Returns:
        ``(ranked, word_importance)``: the ranked words longer than one
        character, and the aggregated score mapping.
    """
    text = row['sentence']

    scores = compute_influences(text, clf, tokenizer, biobert)
    influences, org_probs, perturbed_probs, original_prediction = scores

    ranking = rank_words_by_influence(text, influences, tokenizer, org_probs, perturbed_probs)
    ranked, words, imp_scores, org_prob_scores, perturbed_prob_scores = ranking

    # Drop empty and single-character entries from the ranking.
    ranked = list(filter(lambda candidate: len(candidate) > 1, ranked))

    word_importance = aggregate_word_importance(words, imp_scores, org_prob_scores, perturbed_prob_scores)
    save_word_importance_to_file(word_importance, row, original_prediction)
    return ranked, word_importance


# Apply the main function for each value in the "sentence" column
train_df = pd.read_csv('new_train_data', delimiter='\t')
train_df['sentence'] = train_df['sentence'].apply(remove_stopwords)
# Keep only rows labelled 1, one row per (index, id1, id2) combination.
causal_train_df = train_df[train_df["label"] == 1]
causal_train_df_unique = causal_train_df.drop_duplicates(subset=['index', 'id1', 'id2'])


# Alternative inputs kept for reference:
# causal_test_df=test_df[test_df["label"]==1]
# causal_df_unique = causal_test_df.drop_duplicates(subset=['index', 'id1', 'id2'])

# samp_abst=causal_df_unique[causal_df_unique["index"]==25064704]

# NOTE: save_word_importance_to_file opens its output in append mode, so
# re-running this cell adds the same rows to the TSV again; remove the file
# first when a clean result is needed.
# iterrows() is a generator without a length, so the row count is passed
# explicitly to let tqdm show percentage and remaining time.
for _, row in tqdm(causal_train_df_unique.iterrows(), total=len(causal_train_df_unique)):
    ranked, word_importance = main(row)


132it [08:04,  3.67s/it]


In [None]:
# Bare expression: shows the repr of the loaded BioBERT module for inspection.
# This cell has no execution count, i.e. it was not run in the saved notebook.
model