<a href="https://colab.research.google.com/github/Chetan2326/Diagnosis-from-Textual-Description/blob/main/CDSSM_InfoRetrieval_Stratify_02.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import tensorflow as tf
import itertools
import pandas as pd
import math
import random
from collections import defaultdict
import itertools
import re


from sklearn.model_selection import train_test_split
from sklearn.neighbors import KDTree

import keras 
from keras.backend import max as MAX
from keras.layers import Activation, Input, concatenate, dot
from keras.layers.core import Dense, Lambda, Reshape
from keras.layers.convolutional import Convolution1D
from keras.models import Model
from keras.callbacks import ModelCheckpoint, EarlyStopping, TensorBoard, ReduceLROnPlateau
from keras import backend as bk

import nltk
nltk.download('punkt')
from nltk import ngrams
from nltk.tokenize import word_tokenize
from string import punctuation

punctuation = list(punctuation)

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
gpus = tf.config.list_physical_devices('GPU')
if gpus:
  # Restrict TensorFlow to only allocate 1GB of memory on the first GPU
  try:
    tf.config.set_logical_device_configuration(
        gpus[0],
        [tf.config.LogicalDeviceConfiguration(memory_limit=4096)])
    logical_gpus = tf.config.list_logical_devices('GPU')
    print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
  except RuntimeError as e:
    # Virtual devices must be set before GPUs have been initialized
    print(e)

1 Physical GPUs, 1 Logical GPUs


In [4]:
data_file = '/content/drive/MyDrive/cdssm/data/Full_Data.csv'
#/content/drive/MyDrive/Full_Data.csv

#Read the data
df=pd.read_csv(data_file,encoding = "ISO-8859-1",on_bad_lines='skip')

#query with not null values
df.loc[df['SPED_COMLTEXT'].isnull(),'SPED_COMLTEXT'] = "#"
df.loc[df['DESCRIPTION'].isnull(),'DESCRIPTION'] = "#"


#Get unique values of ICDs
unique_icds = df['SPDD_ICDCODE'].unique()
print("Number of Unique ICDs are ",len(unique_icds))

Number of Unique ICDs are  11437


In [5]:
df = df.sort_values(by='SPDD_ICDCODE')

In [6]:
# df = df[:20]
df.head

<bound method NDFrame.head of             RCH_PIN                                      SPED_COMLTEXT  \
530619   AMH0115614                        runny nose with nasal block   
530620   AMH0115614                                              cough   
530618   AMH0115614                                         high fever   
530617   AMH0115614                                        sore throat   
1163196  IBS0177952  LEFT SIDED NUMBNESS,,, NECK PAIN MILD...",M47....   
...             ...                                                ...   
2107376  UMC0030349                                            faituge   
2107383  UMC0030349                                            faituge   
2112642  UMC0031282                         pain in the lower abdomen    
2112651  UMC0031282                         pain in the lower abdomen    
2112659  UMC0031282                         pain in the lower abdomen    

              SPDD_ICDCODE DESCRIPTION  
530619              R09.81           #  

In [7]:
df.loc[df['DESCRIPTION'] == ' unspecified"', 'DESCRIPTION'] = '#'


In [8]:
df.groupby(by='DESCRIPTION').count().sort_values(by='SPDD_ICDCODE')

Unnamed: 0_level_0,RCH_PIN,SPED_COMLTEXT,SPDD_ICDCODE
DESCRIPTION,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
"Multiple fractures of ribs, unspecified side, subsequent encounter for fracture with routine healing",2,2,2
Other specified disorders of left middle ear and mastoid,2,2,2
"Ectopic testis, unilateral",2,2,2
Other specified disorders of left middle ear and mastoid in diseases classified elsewhere,2,2,2
Other specified disorders of parathyroid gland,2,2,2
...,...,...,...
Myalgia,51928,51928,51928
Essential (primary) hypertension,55739,55739,55739
"Hyperlipidemia, unspecified",58514,58514,58514
"Vitamin D deficiency, unspecified",59907,59907,59907


In [9]:
#Parameters

MAX_QUERY_LENGTH = 10
MAX_DOCUMENT_LENGTH = 10
K = 300 # Dimensionality of the max-pooling layer. See section 3.4.
L = 128 # Dimensionality of latent semantic space. See section 3.5.
FILTER_LENGTH = 3 # Convolution window size.
NEGATIVE_SAMPLE_SIZE = 4
TRIGRAM_INDICES = None
TOTAL_TRIGRAMS = 0

In [10]:
# Custom data generator : It generates data batch-wise, processes it and sends for training
class ICDRetrivalDataGenerator(tf.keras.utils.Sequence):

    def __init__(self, df, X_col, y_col, mode,
                 batch_size, negative_sample_size, sub_mode = None, TRIGRAM_INDICES = None,
                 TOTAL_TRIGRAMS = 0, shuffle=True,
                 single_query='', unique_document_df=None, model = None):
        self.df = df.copy()
        self.X_col = X_col
        self.y_col = y_col
        self.pin_column = 'RCH_PIN'  # For negative sampling
        self.discarded_words = ['undefined', 'unspecified']  # words to be discarded from queries/documents
        self.replace_characters = [',', ';', '.', '/', ':', '+', '(', ')', '-', '%', '<', '>', '&', '*', '[', ']',
                                    '?', '_', '"', "@"]
        self.shuffle = shuffle
        self.mode = mode  # one out of 'train' or 'predict_on_single'
        self.batch_size = batch_size
        self.negative_sample_size = negative_sample_size
        self.tree = None

        if (self.mode == 'train'):
            self.create_all_possible_trigrams()
            self.map_unique_pins_to_icds()

            self.unique_document_df = self.df[['SPDD_ICDCODE', self.y_col]]
            self.unique_document_df = self.unique_document_df.drop_duplicates(subset=['DESCRIPTION'])

            if (sub_mode == 'kdtree'):
                self.__kdtree(model)

        if (self.mode == 'predict_on_single'):

            self.total_letter_trigrams = TOTAL_TRIGRAMS
            self.trigrams_to_indices = TRIGRAM_INDICES
            self.single_query = single_query
            self.unique_document_df = unique_document_df


    def create_all_possible_trigrams(self):
        # To get all the permutations of ngrams possible
        # One special character
        d = '#'
        # All the alphabets
        for i in range(ord('a'), ord('z') + 1):
            d += chr(i)
        # To get ngrams with 'aaa','bbb',etc
        d = d + d + d
        self.trigrams_to_indices = dict()  # It holds mapping of tricharacters to indices
        # Permutations of d to form ngrams
        ngrams_ = list(itertools.permutations(d, 3))
        i = int(0)
        # Forming dictionary with ngrams and their index
        for ngram in list(ngrams_):
            if ''.join(ngram) not in self.trigrams_to_indices.keys():
                self.trigrams_to_indices[''.join(ngram)] = i
                i += 1
        self.total_letter_trigrams = len(
            self.trigrams_to_indices)  # It should come (26 (characters a-z) + 1 (special character #)^3 = 19683

    # For negative sampling
    def map_unique_pins_to_icds(self):
        self.id_idx = self.df.groupby(self.pin_column)['SPDD_ICDCODE'].apply(list).to_dict()

    # This converts the input words in text form into vectors
    def word_to_vector(self, word):
        word = word.lower()
        word = '#' + word + '#'
        trigrams = list(map(lambda x: ''.join(x), list(ngrams(word, 3))))
        wordvector = [0] * self.total_letter_trigrams
        for w in trigrams:
            try:
                wordvector[self.trigrams_to_indices[w]] = 1
            except:
#                 print("missing trigram " + w)
                pass
        return wordvector

    # Function to convert sentence to list of word vectors
    def sentence_to_vectors(self, sentence, sentence_type='query'):
        if (type(sentence) == float):
            sentence = "###"
        sentence = sentence.lower()
        words = sentence.split(" ")
        max_sentence_length = MAX_QUERY_LENGTH
        if (sentence_type == 'document'):
            max_sentence_length = MAX_DOCUMENT_LENGTH
        words = [x for x in words if x not in self.discarded_words]
        words = words[:max_sentence_length]
        words = words
        word_vectors = []
        for word in words:
            if word:
                for specialChar in self.replace_characters:
                    word = word.replace(specialChar, '')
                for i in range(10):
                    word = word.replace(str(i), '')  # Replace all digits with ''
                vec = self.word_to_vector(word)
                word_vectors.append(vec)
            else:
                vec = self.word_to_vector('#')
                word_vectors.append(vec)

        if (max_sentence_length > len(word_vectors)):
            word_vectors = word_vectors + [[0.0] * self.total_letter_trigrams for x in
                                           range(max_sentence_length - len(word_vectors))]

        word_vectors = [[0.0] * self.total_letter_trigrams] + word_vectors + [
            [0.0] * self.total_letter_trigrams]  # Padding in start and end

        return word_vectors

    def on_epoch_end(self):
        self.df = self.df.sample(frac=1)

    def __kdtree(self, model):
        inp = model.input                          # input placeholder
        outputs = model.layers[-2].output          # all layer outputs
        functors = bk.function([inp], [outputs])   # evaluation functions

        docs = []

        for doc in self.unique_document_df['DESCRIPTION']:
          word_vector = self.sentence_to_vectors(doc, sentence_type='document')
          vec = np.array(word_vector)[np.newaxis,...]
          docs += [functors([0,vec])]
        
        print(np.array(docs).shape)
        # self.tree = KDTree(docs)
        self.tree = KDTree(np.array(docs).reshape(len(self.unique_document_df),128))

        
    def __get_data(self, batches):
        # Generates data containing batch_size samples

        query_vectors = []
        document_vectors = []
        labels = []

        for index, row in batches.iterrows():

            if (self.mode == 'train'):

                query_vector = self.sentence_to_vectors(row[self.X_col])

                # Add positive sample pair
                query_vectors.append(query_vector)
                document_vectors.append(self.sentence_to_vectors(row[self.y_col], sentence_type='document'))
                labels.append(1)

                # Extract negative samples
                pin = row[self.pin_column]
                remove_icds = self.id_idx[pin]
                current_icd = row['SPDD_ICDCODE']
                if (current_icd not in remove_icds):
                    remove_icds.append(current_icd)
                
                # print(self.unique_document_df['SPDD_ICDCODE'])

                available_documents_for_negative_sampling = self.unique_document_df[~self.unique_document_df['SPDD_ICDCODE'].isin(remove_icds)]
                available_documents_for_negative_sampling = available_documents_for_negative_sampling[
                    'DESCRIPTION'].tolist()
                k = self.negative_sample_size
                if (k > len(available_documents_for_negative_sampling)):
                    k = len(available_documents_for_negative_sampling)
                if (k != 0):
                    nagative_samples = random.sample(available_documents_for_negative_sampling, k)
                    negative_document_vectors = [self.sentence_to_vectors(sample, sentence_type='document') for sample
                                                 in nagative_samples]
                    # Add negative samples pairs
                    query_vectors += [query_vector for x in range(k)]
                    document_vectors += negative_document_vectors
                    labels += [0 for x in range(k)]

            elif (self.mode == 'predict_on_single'):

                query_vectors.append(self.sentence_to_vectors(self.single_query))
                document_vectors.append(self.sentence_to_vectors(row[self.y_col], sentence_type='document'))
                labels.append(1)

            else:
                raise ("Improper mode")

        query_vectors = np.array(query_vectors)
        document_vectors = np.array(document_vectors)
        labels = np.array(labels)

        X_batch = [query_vectors, document_vectors]

        return X_batch, labels

    def __getitem__(self, index):
        if (self.mode == 'train'):
            batches = self.df[index * self.batch_size:(index + 1) * self.batch_size]
            X, y = self.__get_data(batches)
        elif (self.mode == 'predict_on_single'):
            self.unique_document_names = self.unique_document_df['SPDD_ICDCODE'].values
            batches = self.unique_document_df[index * self.batch_size:(index + 1) * self.batch_size]  # As only unique documents needed for prediction
            X, y = self.__get_data(batches)
        else:
            raise ('Improper Mode')
        return X, y

    def __len__(self):
        if (self.mode == 'train'):
            return len(self.df) // self.batch_size
        elif (self.mode == 'predict_on_single'):
            return math.ceil(len(self.unique_document_df) // self.batch_size) + 1
        else:
            raise ('Improper Mode')



In [11]:
datagen = ICDRetrivalDataGenerator(df,
                         X_col='SPED_COMLTEXT',
                         y_col='DESCRIPTION',
                         mode='train',
                         batch_size=5,
                         negative_sample_size=0)

TOTAL_TRIGRAMS = datagen.total_letter_trigrams
TRIGRAM_INDICES = datagen.trigrams_to_indices
print(TOTAL_TRIGRAMS)

19683


In [12]:
# Build Model
def create_model():
  query = Input(shape=(MAX_QUERY_LENGTH + 2, TOTAL_TRIGRAMS))
  doc = Input(shape=(MAX_DOCUMENT_LENGTH + 2, TOTAL_TRIGRAMS))

  query_conv = Convolution1D(K, FILTER_LENGTH, padding="valid", input_shape=(MAX_QUERY_LENGTH + 2, TOTAL_TRIGRAMS),
                            activation="tanh", use_bias=False)(query)  # See equation (2).
  query_max = Lambda(lambda x: MAX(x, axis=1), output_shape=(K,))(query_conv)
  query_sem = Dense(L, activation="tanh", input_dim=K)(query_max)  # See section 3.5.

  doc_conv = Convolution1D(K, FILTER_LENGTH, padding="valid", input_shape=(MAX_DOCUMENT_LENGTH + 2, TOTAL_TRIGRAMS),
                          activation="tanh", use_bias=False)(doc)
  doc_max = Lambda(lambda x: MAX(x, axis=1), output_shape=(K,))(doc_conv)
  doc_sem = Dense(L, activation="tanh", input_dim=K)(doc_max)

  cosine_similarities = dot([query_sem, doc_sem], axes=1, normalize=True)  # See equation (4).
  probs = cosine_similarities  # See equation (5).  

  model = Model(inputs=[query, doc], outputs=probs)
  optimizer = keras.optimizers.Adadelta(learning_rate=0.005)
  model.compile(optimizer = optimizer, loss = "binary_crossentropy", metrics=['acc'])
  # model.compile(optimizer="adadelta", loss="binary_crossentropy", metrics=['accuracy'])
  return model

mirrored_strategy = tf.distribute.MirroredStrategy()

with mirrored_strategy.scope():
  model = create_model()
model.summary()


Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 12, 19683)]  0           []                               
                                                                                                  
 input_2 (InputLayer)           [(None, 12, 19683)]  0           []                               
                                                                                                  
 conv1d (Conv1D)                (None, 10, 300)      17714700    ['input_1[0][0]']                
                                                                                                  
 conv1d_1 (Conv1D)              (None, 10, 300)      17714700    ['input_2[0][0]']                
                                                                                              

In [35]:
input_shape=(MAX_QUERY_LENGTH + 2, TOTAL_TRIGRAMS)

inp = model.input                          # input placeholder
outputs = model.layers[-2].output          # all layer outputs
functors = bk.function([inp], [outputs])   # evaluation functions

# Testing
num_of_docs = 10
test = [np.random.random(input_shape)[np.newaxis,...] for _ in range(num_of_docs)]
# print(np.array(test).shape)
layer_outs = [functors([0,test_]) for test_ in test]
# print(np.array(layer_outs).shape)
# print(layer_outs)
output = np.array(layer_outs).reshape(num_of_docs,128)
# print(output.shape)
# print(output)

tree = KDTree(output)
tree_data, node_index, node_data, node_bounds = tree.get_arrays()

# inp = model.input                                           # input placeholder
# outputs = [layer.output for layer in model.layers]          # all layer outputs
# functor = K.function([inp, K.learning_phase()], outputs )   # evaluation function

# # Testing
# test = np.random.random(input_shape)[np.newaxis,...]
# layer_outs = functor([[test,test], 1.])
# print(layer_outs)

In [40]:
model_name = "model_stratify_new_data_1.h5"
save_model_path = "/content/drive/MyDrive/cdssm/models/"+model_name
model.load_weights(save_model_path)

datagen = ICDRetrivalDataGenerator(df,
                                    X_col='SPED_COMLTEXT',
                                    y_col='DESCRIPTION',
                                    mode='train',
                                    sub_mode='kdtree',
                                    batch_size=5,
                                    negative_sample_size=4,
                                    single_query="",
                                    model=model)
tree = datagen.tree

(11324, 1, 1, 128)


In [42]:
import pickle

tree_name = "first_kdtree.pkl"
tree_path = "/content/drive/MyDrive/cdssm/kdtree/"+tree_name

with open(tree_path, 'wb') as file:
  pickle.dump(tree, file)


In [43]:
with open(tree_path, 'rb') as file:
  tree = pickle.load(file)
tree_data, node_index, node_data, node_bounds = tree.get_arrays()

In [44]:
tree_data.shape, tree_data

((11324, 128),
 array([[ 0.0212833 ,  0.00964101, -0.07100918, ..., -0.00044581,
          0.08052758, -0.02634126],
        [ 0.03133357, -0.01392449, -0.0856846 , ..., -0.00585479,
          0.09355202, -0.03721176],
        [ 0.06557032, -0.10550984, -0.06639732, ..., -0.09118482,
          0.02494656, -0.09683115],
        ...,
        [ 0.05031999,  0.09429108, -0.02349198, ..., -0.02346195,
         -0.06449218, -0.03148569],
        [ 0.04650564,  0.0264884 ,  0.02296255, ..., -0.00799597,
         -0.02186086, -0.02759873],
        [ 0.05065811,  0.00828418,  0.00440481, ..., -0.04292645,
         -0.01578663,  0.00094439]]))

In [45]:
node_index.shape, node_index

((11324,), array([4075, 6779, 2614, ..., 2638,  833, 2600]))

In [46]:
node_data.shape, node_data

((511,),
 array([(    0, 11324, 0, 2.09620554), (    0,  5662, 0, 1.93357537),
        ( 5662, 11324, 0, 1.52363708), (    0,  2831, 0, 1.63597253),
        ( 2831,  5662, 0, 1.55399931), ( 5662,  8493, 0, 1.20626071),
        ( 8493, 11324, 0, 1.42248343), (    0,  1415, 0, 1.43668247),
        ( 1415,  2831, 0, 1.48951501), ( 2831,  4246, 0, 1.41888544),
        ( 4246,  5662, 0, 1.28856969), ( 5662,  7077, 0, 1.06648749),
        ( 7077,  8493, 0, 1.12299364), ( 8493,  9908, 0, 1.31877447),
        ( 9908, 11324, 0, 1.17551653), (    0,   707, 0, 1.25284523),
        (  707,  1415, 0, 1.29629409), ( 1415,  2123, 0, 1.40621439),
        ( 2123,  2831, 0, 1.14536014), ( 2831,  3538, 0, 1.17563074),
        ( 3538,  4246, 0, 1.3292918 ), ( 4246,  4954, 0, 0.99204632),
        ( 4954,  5662, 0, 1.20566087), ( 5662,  6369, 0, 1.00191368),
        ( 6369,  7077, 0, 0.90932925), ( 7077,  7785, 0, 1.03566089),
        ( 7785,  8493, 0, 0.99727691), ( 8493,  9200, 0, 1.05287163),
        ( 9

In [47]:
node_bounds.shape, node_bounds

((2, 511, 128),
 array([[[-0.23240016, -0.15180783, -0.16472344, ..., -0.12546661,
          -0.32772285, -0.15355928],
         [-0.23240016, -0.03257354, -0.16209266, ..., -0.12485632,
          -0.32772285, -0.15355928],
         [-0.09241162, -0.15180783, -0.16472344, ..., -0.12546661,
          -0.17237113, -0.14477174],
         ...,
         [-0.0473935 , -0.0611182 , -0.03725491, ..., -0.06655216,
          -0.10322763, -0.04925575],
         [-0.03191813, -0.01264756, -0.03020702, ..., -0.04584377,
          -0.09809712, -0.07745761],
         [-0.0440845 , -0.02385855, -0.05773157, ..., -0.03039812,
          -0.11423066, -0.066021  ]],
 
        [[ 0.20399837,  0.4304499 ,  0.21230552, ...,  0.2977815 ,
           0.09355202,  0.22005461],
         [ 0.20315823,  0.4304499 ,  0.21230552, ...,  0.2977815 ,
           0.05159216,  0.22005461],
         [ 0.20399837,  0.1560647 ,  0.16161941, ...,  0.11775033,
           0.09355202,  0.08511035],
         ...,
         [ 0.1160

In [None]:
# It takes trained model and input dataframe, and computes top-k documents for each query
def fetch_documents_on_queries(model, dataframe):
  top_k = 10
  actual_documents = []
  top_predicted_documents = []
  datagen = ICDRetrivalDataGenerator(df,
                                      X_col='SPED_COMLTEXT',
                                      y_col='DESCRIPTION',
                                      mode='train',
                                      batch_size=5,
                                      negative_sample_size=4,
                                      single_query="")
  unique_document_df = datagen.unique_document_df

  print("Total queries to evaluate ", len(dataframe))
  for index, row in dataframe.iterrows():
      query = row['SPED_COMLTEXT']
      datagen_for_query = ICDRetrivalDataGenerator(dataframe,
                                                    X_col='SPED_COMLTEXT',
                                                    y_col='DESCRIPTION',
                                                    mode='predict_on_single',
                                                    batch_size=5,
                                                    negative_sample_size=4,
                                                    TRIGRAM_INDICES = TRIGRAM_INDICES,
                                                    TOTAL_TRIGRAMS = TOTAL_TRIGRAMS,
                                                    single_query=query,
                                                    unique_document_df=unique_document_df)
      prediction_probabilities = model.predict_generator(datagen_for_query, verbose=1)
      prediction_probabilities = [item for sublist in prediction_probabilities for item in sublist]
      predicted_top_document_indices = np.argsort(-np.array(prediction_probabilities), kind='mergesort')[:top_k]
      predicted_documents = [datagen_for_query.unique_document_names[x] for x in predicted_top_document_indices]

      actual_documents.append(row['SPDD_ICDCODE'])
      top_predicted_documents.append(predicted_documents)
      print("Processed ", index)

  return actual_documents, top_predicted_documents


# This function computes top-k accuracy based on parameter 'k'
def compute_top_k_accuracy(actual_documents, top_predicted_documents, k):
  total = len(actual_documents)
  correct = 0
  for i in range(total):
      if (actual_documents[i] in top_predicted_documents[i][:k]):
          correct += 1
  accuracy = (correct * 1.0) / total
  return accuracy

In [None]:
train_test_split_size = 0.8

# num_samples = df.shape[0]
# df = df.sample(frac=1).reset_index(drop=True)  # randomly shuffle the complete dataframe
# split_sample_numbers = int(num_samples * train_test_split_size)

# train_df = df.iloc[:split_sample_numbers, :]
# test_df = df.iloc[split_sample_numbers:, :]

train_x, test_x, train_y, test_y = train_test_split(df.iloc[1:,:-1],
                                                    df.iloc[1:,-1],
                                                    stratify=df.iloc[1:,-1],
                                                    train_size=train_test_split_size)

train_df = pd.concat([train_x, train_y], axis=1)
test_df = pd.concat([test_x, test_y], axis=1)

model_name = "model_stratify_new_data_1.h5"
save_model_path = "/content/drive/MyDrive/cdssm/models/"+model_name
checkpoint_path = "/content/drive/MyDrive/cdssm/checkpoint/1"

In [None]:
experiment_mode = 'train' #Either of 'train' or 'test'. Please don't get confused with mode in datagenerator


if(experiment_mode == 'train'):

    num_samples = train_df.shape[0]
    train_validation_split = 0.8
    split_sample_numbers = int(num_samples * train_validation_split)
    

    training_df = train_df.iloc[:split_sample_numbers, :]
    validation_df = train_df.iloc[split_sample_numbers:, :]

    # EarlyStopping
    early_stopping = EarlyStopping(
        monitor='loss',
        patience=5,
        verbose=1,
        restore_best_weights=True
    )

    reduce_lr = ReduceLROnPlateau(
        monitor='loss',
        factor=0.001,
        patience=2,
        verbose=1
    )

    model_checkpoint_callback = ModelCheckpoint(
        filepath=checkpoint_path,
        save_weights_only=True,
        save_freq=1000,
        monitor='loss',
        mode='min',
        save_best_only=True)

    traingen = ICDRetrivalDataGenerator(training_df,
                                        X_col='SPED_COMLTEXT',
                                        y_col='DESCRIPTION',
                                        mode='train',
                                        batch_size=32,
                                        negative_sample_size=NEGATIVE_SAMPLE_SIZE)

    validgen = ICDRetrivalDataGenerator(validation_df,
                                        X_col='SPED_COMLTEXT',
                                        y_col='DESCRIPTION',
                                        mode='train',
                                        batch_size=32,
                                        negative_sample_size=NEGATIVE_SAMPLE_SIZE)

    history = model.fit(traingen,
                                  validation_data=validgen,
                                  validation_freq=1,
                                  epochs=1,
                                  verbose=1,
                                  class_weight={0:1,1:5},
                                  callbacks=[early_stopping, reduce_lr, model_checkpoint_callback])

    history.history
    model.save_weights(save_model_path)
    input_shape=(MAX_QUERY_LENGTH + 2, TOTAL_TRIGRAMS)

    inp = model.input                          # input placeholder
    outputs = model.layers[-2].output          # all layer outputs
    functors = bk.function([inp], [outputs])   # evaluation functions

    
    print("Model Saved")  # Download model as soon as the training finished. You can find the trained model weights on the left folder pannel
    
elif (experiment_mode == 'test'):

    # Load model
    model.load_weights(save_model_path)
    print("Loaded")

    test_df = test_df.sample(frac=1).reset_index(drop=True)

    # test_df = test_df.iloc[:100,:]
    actual_documents, top_predicted_documents = fetch_documents_on_queries(model, test_df.sample(n=100))
    myfile = open('Accuracy_w2v.txt','a')
    top_3_accuracy = compute_top_k_accuracy(actual_documents, top_predicted_documents, 3)
    print("Top 3 accuracy for w2v with ", model_name, " : ",top_3_accuracy, file=myfile)
    top_5_accuracy = compute_top_k_accuracy(actual_documents, top_predicted_documents, 5)
    print("Top 5 accuracy for w2v", model_name, " : ", top_5_accuracy, file=myfile)
    myfile.close()
