In [7]:
import pandas as pd
from transformers import BertTokenizer, BertModel
import torch
import numpy as np
from tqdm import tqdm
import nltk
from nltk.corpus import stopwords
from sklearn.svm import SVC
from sklearn.metrics import classification_report
import xgboost as xgb
from sklearn.ensemble import RandomForestClassifier
from catboost import CatBoostClassifier
import pickle
from joblib import dump, load
from interpretation_code import interpretation
# from k_fold_cv import k_fold



class model_trainer:
    """Train and evaluate relation classifiers on BERT entity-marker embeddings.

    The train/test files are tab-separated and must contain ``sentence`` and
    ``label`` columns. Sentences are expected to contain the entity markers
    ``@GeneSrc$`` and ``@DiseaseTgt$``. Each sentence is turned into a single
    feature vector by averaging BERT hidden states around each marker (see
    ``get_specific_token_embeddings``). The vectors are then fed to SVM,
    XGBoost and Random Forest classifiers.

    Attributes:
        train_df, test_df: the loaded datasets. ``create_dataset`` rewrites
            their ``sentence`` column in place with stopwords removed.
        model_name: the loaded ``BertModel``, not the name string. The
            attribute name is kept for compatibility with existing callers.
        tokenizer: the matching ``BertTokenizer``.
        device: torch device the model was moved to (set by ``load_model``).
    """

    # Entity markers the sentences are expected to contain.
    GENE_MARKER = "@GeneSrc$"
    DISEASE_MARKER = "@DiseaseTgt$"

    def __init__(self, train_path, test_path, model_name):
        self.train_df = self.load_dataset(train_path)
        self.test_df = self.load_dataset(test_path)
        # NOTE: despite its name, this attribute holds the loaded BertModel.
        self.model_name = self.load_model(model_name)
        self.tokenizer = self.load_tokenizer(model_name)
        # Set once create_dataset() has stripped stopwords, so it is not redone.
        self._stopwords_removed = False

    def load_model(self, model_name):
        """Load a pretrained ``BertModel`` onto CUDA when available, else CPU.

        Side effect: sets ``self.device``.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        biobert = BertModel.from_pretrained(model_name).to(self.device)
        print("BioBERT model loaded")
        return biobert

    def load_tokenizer(self, model_name):
        """Load the ``BertTokenizer`` matching ``model_name``.

        An earlier experiment registered the marker words as extra tokens
        (``tokenizer.add_tokens`` plus ``model.resize_token_embeddings``).
        Features are currently computed with the stock vocabulary.
        """
        tokenizer = BertTokenizer.from_pretrained(model_name)
        print("Tokenizer loaded")
        return tokenizer

    def load_dataset(self, data_path):
        """Read a tab-separated dataset file into a DataFrame."""
        return pd.read_csv(data_path, delimiter='\t')

    def remove_stopwords(self, text):
        """Return ``text`` with NLTK English stopwords removed (case-insensitive).

        Tokens come from ``nltk.word_tokenize`` and are re-joined with single
        spaces. The NLTK ``stopwords`` corpus (and whatever data
        ``word_tokenize`` needs) must already be downloaded.
        """
        stop_words = getattr(self, "_stop_words", None)
        if stop_words is None:
            # Build the set once per instance rather than once per sentence.
            stop_words = set(stopwords.words('english'))
            self._stop_words = stop_words
        return ' '.join(
            word for word in nltk.word_tokenize(text) if word.lower() not in stop_words
        )

    def create_dataset(self):
        """Strip stopwords and return ``(X_train, y_train, X_test, y_test)`` lists.

        The ``sentence`` columns of ``train_df``/``test_df`` are rewritten in
        place; ``interpretation_call`` later reads rows of ``test_df``.
        Stripping happens only on the first call. Repeated calls, e.g. through
        ``generate_embeddings``, therefore do not re-process cleaned text.
        """
        if not getattr(self, "_stopwords_removed", False):
            self.train_df['sentence'] = self.train_df['sentence'].apply(self.remove_stopwords)
            self.test_df['sentence'] = self.test_df['sentence'].apply(self.remove_stopwords)
            self._stopwords_removed = True

        X_train = self.train_df['sentence'].tolist()
        y_train = self.train_df['label'].tolist()
        X_test = self.test_df['sentence'].tolist()
        y_test = self.test_df['label'].tolist()

        print("Dataset created")
        return X_train, y_train, X_test, y_test

    @staticmethod
    def _find_marker_positions(tokens, marker_tokens):
        """Return every position in ``tokens`` covered by an exact run of ``marker_tokens``."""
        width = len(marker_tokens)
        positions = []
        if width == 0:
            return positions
        for start in range(len(tokens) - width + 1):
            if tokens[start:start + width] == marker_tokens:
                positions.extend(range(start, start + width))
        return positions

    def get_specific_token_embeddings(self, sentence, context_range=2):
        """Return a ``(1, 2 * hidden)`` feature row for the gene/disease markers.

        Each position inside an occurrence of ``GENE_MARKER`` (resp.
        ``DISEASE_MARKER``) contributes the last-layer hidden states of the
        tokens up to ``context_range`` positions on either side. All
        contributed vectors are averaged, and the gene average is concatenated
        with the disease average.

        Token positions are read from the model's own input ids, which include
        ``[CLS]``/``[SEP]`` and reflect truncation to 512 tokens. They therefore
        index the hidden states directly. Only complete marker sequences are
        matched, so a lone ``@`` or ``$`` elsewhere is not counted as a marker.

        If a marker is absent (e.g. truncated away), its half of the row is
        NaN and the sentence is printed. ``drop_null_embeddings`` removes such
        rows.

        NOTE: these features differ from those of earlier revisions of this
        method. Those were shifted one position against the hidden states and
        also picked up stray ``@``/``$`` tokens. Regenerate any embeddings or
        classifiers saved from the old features before using them together.
        """
        inputs = self.tokenizer(sentence, return_tensors='pt', padding=True,
                                truncation=True, max_length=512).to(self.device)
        # Tokens exactly as the model sees them, aligned 1:1 with its outputs.
        tokens = self.tokenizer.convert_ids_to_tokens(inputs['input_ids'][0].tolist())

        gene_positions = self._find_marker_positions(
            tokens, self.tokenizer.tokenize(self.GENE_MARKER))
        disease_positions = self._find_marker_positions(
            tokens, self.tokenizer.tokenize(self.DISEASE_MARKER))

        with torch.no_grad():
            outputs = self.model_name(**inputs)
        embeddings = outputs['last_hidden_state'][0]  # first (only) sentence in the batch
        seq_len, hidden_size = embeddings.shape

        def context_mean(positions):
            windows = [
                embeddings[max(0, pos - context_range):min(pos + context_range + 1, seq_len)]
                for pos in positions
            ]
            if not windows:
                # Marker missing: emit NaNs so the row can be dropped downstream.
                return torch.full((hidden_size,), float('nan'),
                                  device=embeddings.device, dtype=embeddings.dtype)
            return torch.cat(windows).mean(dim=0)

        combined = torch.cat([context_mean(gene_positions), context_mean(disease_positions)])
        combined_np = combined.cpu().numpy().reshape(1, -1)

        if np.isnan(combined_np).any():
            print(sentence)

        return combined_np

    def _embed_sentences(self, sentences):
        """Stack one feature row per sentence; empty input gives a zero-row array."""
        rows = [self.get_specific_token_embeddings(sentence) for sentence in tqdm(sentences)]
        if not rows:
            return np.empty((0, 2 * self.model_name.config.hidden_size))
        return np.vstack(rows)

    def generate_embeddings(self):
        """Compute the train and test feature matrices.

        Returns:
            ``(X_train_embeddings, X_test_embeddings)`` as NumPy arrays, one row
            per sentence in dataset order. Rows for sentences missing a marker
            contain NaN.

        Any error is printed and re-raised. Continuing would leave the results
        unbound.
        """
        X_train, _, X_test, _ = self.create_dataset()
        try:
            X_train_embeddings = self._embed_sentences(X_train)
            X_test_embeddings = self._embed_sentences(X_test)
        except Exception as e:
            print(e)
            raise
        print("Embeddings generated")
        return X_train_embeddings, X_test_embeddings

    @staticmethod
    def _drop_nan_rows(embeddings, labels):
        """Return ``(rows without NaN, labels of those rows)``, matched by position."""
        frame = pd.DataFrame(embeddings)
        labels = list(labels)
        if len(frame) != len(labels):
            raise ValueError(f"{len(frame)} embedding rows but {len(labels)} labels")
        keep = frame.notna().all(axis=1).to_numpy()
        return frame[keep], [label for label, kept in zip(labels, keep) if kept]

    def drop_null_embeddings(self, X_train_embeddings, X_test_embeddings, y_train, y_test):
        """Drop embedding rows containing NaN together with their labels.

        Rows are matched to labels by position. The returned X's are
        DataFrames that keep the original row index of the surviving rows. The
        returned y's are new lists; the input lists are not modified.

        Returns:
            ``(X_train_embeddings, X_test_embeddings, y_train, y_test)``.

        Raises:
            ValueError: if an embedding matrix and its label list differ in length.
        """
        X_train_df, y_train_kept = self._drop_nan_rows(X_train_embeddings, y_train)
        X_test_df, y_test_kept = self._drop_nan_rows(X_test_embeddings, y_test)
        return X_train_df, X_test_df, y_train_kept, y_test_kept

    def svm_classifiation(self, X_train_embeddings, y_train, X_test_embeddings, y_test):
        """Fit a polynomial-kernel SVC, print a classification report, return the model."""
        print("Doing classification using SVM...")
        # Class 1 is weighted 30x to counter the label imbalance.
        clf = SVC(kernel='poly', degree=10, probability=True, C=10, class_weight={0: 1, 1: 30})
        clf.fit(X_train_embeddings, y_train)

        y_pred = clf.predict(X_test_embeddings)
        print("Results of SVM classifier are: ")
        print(classification_report(y_test, y_pred))
        return clf

    # Correctly spelled alias; the original name is kept for existing callers.
    svm_classification = svm_classifiation

    def xg_boost_classification(self, X_train_embeddings, y_train, X_test_embeddings, y_test,
                                model_path='CDR_trained_model_xgb.joblib'):
        """Fit XGBoost, save it to ``model_path``, print a report, return the model."""
        print("Doing classification using XG Boost...")

        # Positive class weighted 500x here (the SVM/RF use 30).
        clf = xgb.XGBClassifier(scale_pos_weight=500, max_depth=40, learning_rate=0.1,
                                n_estimators=400, gamma=0.1)
        clf.fit(X_train_embeddings, y_train)

        y_pred = clf.predict(X_test_embeddings)

        dump(clf, model_path)

        print("Results of XG Boost are: ")
        print(classification_report(y_test, y_pred))
        return clf

    def random_forest(self, X_train_embeddings, y_train, X_test_embeddings, y_test):
        """Fit a Random Forest, print a classification report, return the model."""
        print("Doing classification using Random Forest...")

        clf = RandomForestClassifier(n_estimators=200, max_depth=50, max_leaf_nodes=500,
                                     n_jobs=5, max_features="sqrt",
                                     class_weight={0: 1, 1: 30}, verbose=True)
        clf.fit(X_train_embeddings, y_train)

        y_pred = clf.predict(X_test_embeddings)
        print("Results of Random Forest are: ")
        print(classification_report(y_test, y_pred))
        return clf

    def interpretation_call(self,
                            model_path='/home/ubuntu/CRED_application/CRED_trained_model_new_data.joblib'):
        """Run ``interpretation`` on every unique causal (label == 1) test example.

        Examples are de-duplicated on ``index``/``id1``/``id2``. The classifier
        is loaded from ``model_path``. This trainer's own tokenizer and model
        are passed, so they match the model behind the feature callback. The
        saved classifier is presumably trained on features from
        ``get_specific_token_embeddings`` (confirm when swapping models).

        Returns:
            list of ``(ranked, word_importance)`` tuples, one per example.
        """
        causal_test_df = self.test_df[self.test_df["label"] == 1]
        causal_df_unique = causal_test_df.drop_duplicates(subset=['index', 'id1', 'id2'])

        # joblib.load unpickles the file: only point model_path at trusted files.
        clf = load(model_path)

        results = []
        for _, row in tqdm(causal_df_unique.iterrows(), total=len(causal_df_unique)):
            ranked, word_importance = interpretation(
                row, clf, self.tokenizer, self.model_name, self.get_specific_token_embeddings)
            results.append((ranked, word_importance))
        return results

    
    
def main(generate=False, run_interpretation=False):
    """Train and evaluate the SVM, XGBoost and Random Forest classifiers.

    Args:
        generate: compute embeddings with BERT instead of reading the cached
            tab-separated files ``X_train_embeddings_new_data.txt`` and
            ``X_test_embeddings.txt``.
        run_interpretation: afterwards run ``interpretation_call``, which can
            take up to about 10 minutes.

    4-fold cross-validation is available via ``k_fold_cv.k_fold(
    mod_tr.train_df, mod_tr.test_df, mod_tr.get_specific_token_embeddings)``
    (its import is commented out at the top of the file).
    """
    mod_tr = model_trainer('new_train_data', 'test_data', "dmis-lab/biobert-base-cased-v1.1")

    if generate:
        # generate_embeddings() already runs create_dataset(); read the labels
        # directly so the stopword stripping is not applied a second time.
        X_train_embeddings, X_test_embeddings = mod_tr.generate_embeddings()
        y_train = mod_tr.train_df['label'].tolist()
        y_test = mod_tr.test_df['label'].tolist()
    else:
        # create_dataset() is still needed for the labels and for its in-place
        # stopword removal on test_df, which interpretation_call reads.
        _, y_train, _, y_test = mod_tr.create_dataset()
        X_train_embeddings = pd.read_csv('X_train_embeddings_new_data.txt', sep='\t', header=None)
        X_test_embeddings = pd.read_csv('X_test_embeddings.txt', sep='\t', header=None)
        print("Embeddings loaded")

    X_train_embeddings, X_test_embeddings, y_train, y_test = mod_tr.drop_null_embeddings(
        X_train_embeddings, X_test_embeddings, y_train, y_test)

    mod_tr.svm_classifiation(X_train_embeddings, y_train, X_test_embeddings, y_test)
    mod_tr.xg_boost_classification(X_train_embeddings, y_train, X_test_embeddings, y_test)
    mod_tr.random_forest(X_train_embeddings, y_train, X_test_embeddings, y_test)

    if run_interpretation:
        print("Interpretation started. It may take up to 10 minutes...")
        mod_tr.interpretation_call()


# In a notebook __name__ is "__main__" too, so the cell still runs main().
if __name__ == "__main__":
    main()


BioBERT model loaded
Tokenizer loaded
Dataset created
Embeddings generated
Doing classification using SVM...
Results of SVM classifier are: 
              precision    recall  f1-score   support

           0       0.94      0.98      0.96       561
           1       0.61      0.35      0.45        54

    accuracy                           0.92       615
   macro avg       0.78      0.67      0.70       615
weighted avg       0.91      0.92      0.91       615

Doing classification using XG Boost...


  if is_sparse(dtype):
  is_categorical_dtype(dtype) or is_pa_ext_categorical_dtype(dtype)
  if is_categorical_dtype(dtype):
  return is_int or is_bool or is_float or is_categorical_dtype(dtype)
  if is_sparse(dtype):
  is_categorical_dtype(dtype) or is_pa_ext_categorical_dtype(dtype)
  if is_categorical_dtype(dtype):
  return is_int or is_bool or is_float or is_categorical_dtype(dtype)
  if is_sparse(dtype):
  is_categorical_dtype(dtype) or is_pa_ext_categorical_dtype(dtype)
  if is_categorical_dtype(dtype):
  return is_int or is_bool or is_float or is_categorical_dtype(dtype)


Results of XG Boost are: 
              precision    recall  f1-score   support

           0       0.94      0.93      0.93       561
           1       0.34      0.37      0.35        54

    accuracy                           0.88       615
   macro avg       0.64      0.65      0.64       615
weighted avg       0.89      0.88      0.88       615

Doing classification using Random Forest...


[Parallel(n_jobs=5)]: Using backend ThreadingBackend with 5 concurrent workers.
[Parallel(n_jobs=5)]: Done  40 tasks      | elapsed:    1.5s
[Parallel(n_jobs=5)]: Done 190 tasks      | elapsed:    7.0s


Results of Random Forest are: 
              precision    recall  f1-score   support

           0       0.93      0.95      0.94       561
           1       0.37      0.31      0.34        54

    accuracy                           0.89       615
   macro avg       0.65      0.63      0.64       615
weighted avg       0.89      0.89      0.89       615



[Parallel(n_jobs=5)]: Done 200 out of 200 | elapsed:    7.4s finished
[Parallel(n_jobs=5)]: Using backend ThreadingBackend with 5 concurrent workers.
[Parallel(n_jobs=5)]: Done  40 tasks      | elapsed:    0.0s
[Parallel(n_jobs=5)]: Done 190 tasks      | elapsed:    0.1s
[Parallel(n_jobs=5)]: Done 200 out of 200 | elapsed:    0.1s finished
