In [15]:
import sys
sys.path.insert(0, "ov-predict/src/")

In [16]:
import numpy as np
from model.lstm import buildModel
from api.model_loader import init_embedding
from api.model_loader import predict_outcome
from preprocessing.InputHelper import InputHelper

In [17]:
# Layout of every embedding vector: NODEVEC_DIM node-vector entries, then
# PUBMED_DIM context entries (presumably PubMed word-vector based -- confirm),
# then VAL_DIMENSIONS value slots. 128 + 200 + 5 = 333, matching the
# "shape of embedding: (38892, 333)" logged when the model is loaded.
NODEVEC_DIM=128
PUBMED_DIM=200
# NOTE(review): 0 presumably makes buildModel create a regression head (the
# model summary shows a single-unit 'output_vals' Dense layer) -- confirm.
NUM_CLASSES=0
# Input files: pre-trained node embeddings, train/test sentences, saved weights.
EMB_FILE='../../core/prediction/graphs/nodevecs/ndvecs.cv.128-10-0.1-0.9-both-0-false.merged.vec'
DATAFILE='../../core/prediction/sentences/train.tsv'
TESTFILE='../../core/prediction/sentences/test.tsv'  # not used in the cells shown
SAVED_MODEL_FILE='ov-predict/saved_models/model.h5'
# Attribute id of the follow-up duration attribute (tokens <Type>:4087191:<...>).
FOLLOWUP_ATTRIB_NAME='4087191'
# Number of trailing value slots per vector; the first of them holds the value.
VAL_DIMENSIONS=5
# Maximum token-sequence length fed to the model (Embedding output (None, 50, 333)).
MAXLEN=50

In [18]:
def is_number(s):
    """Return True when ``float(s)`` succeeds and False when it raises ValueError.

    Other exceptions (e.g. TypeError for ``None``) propagate unchanged.
    """
    try:
        float(s)
    except ValueError:
        return False
    else:
        return True

In [19]:
def isFollowupNode(token):
    """Return True if ``token`` (``<Type>:<AttribId>:<Arm>``) is a follow-up node.

    A token is a follow-up node when its second ``:``-separated field equals
    FOLLOWUP_ATTRIB_NAME. Tokens without a second field (for example the
    empty string produced by a double space or a blank input line) are not
    follow-up nodes and yield False instead of raising IndexError.
    """
    parts = token.split(':')
    return len(parts) > 1 and parts[1] == FOLLOWUP_ATTRIB_NAME
    
def followupAttribOccurrenceIndex(attribseq):
    """Return the position of the first follow-up node in ``attribseq``, or -1."""
    first_match = (
        position
        for position, tok in enumerate(attribseq)
        if isFollowupNode(tok)
    )
    return next(first_match, -1)

In [20]:
'''
To probe the model with different follow-up times, we first load the node
vectors into memory and then change them on the fly.
For example, suppose that in arm (document) D1 the follow-up time was 't_0'
and we want to predict the outcome as if it had been 't'. Then
1. Find the follow-up vector in the dictionary (say v) whose value is closest to 't' (the value we want to predict for).
2. Replace the context part of v with the context part of the current arm's vector.
3. Put the value 't' in place of 't_0' in the resulting vector.

When working with a pre-trained model, the index of a word in the vocabulary
represents a context vector for a feature; e.g., vocabulary index 'i' may
refer to the follow-up attribute of document D1. At run time we simply
replace this i-th vector with the one constructed above, and once the
prediction is done we restore the original vector. This way no new words
need to be inserted into the vocabulary.
'''

class FollowupAnalyzer:
    """Probe a trained outcome model with pseudo follow-up times.

    Attributes:
        inpH: embedding helper returned by ``init_embedding``; its ``pre_emb``
            mapping goes from node tokens to their vectors.
        model: model built by ``buildModel`` with weights loaded from
            SAVED_MODEL_FILE.
        followupNodes: embedding tokens ``<Type>:<FOLLOWUP_ATTRIB_NAME>:<v>``
            whose third field ``v`` parses as a number (follow-up value nodes).

    NOTE(review): the predictions printed further down are identical for every
    pseudo follow-up value, i.e. rewriting ``inpH.pre_emb[node]`` has no
    visible effect on the model output. The model is built from
    ``inpH.embedding_matrix`` in ``__init__``; presumably its Embedding layer
    holds its own copy and ``predict_outcome`` never re-reads ``pre_emb``.
    TODO confirm, and if so also write the modified vector into the Embedding
    layer weights for the node's vocabulary row (and restore it afterwards).
    """

    def constructFollowupNodes(self):
        """Collect the follow-up value nodes into ``self.followupNodes``.

        Tokens with fewer than three ``:``-separated fields cannot carry a
        value and are skipped rather than raising IndexError.
        """
        self.followupNodes = []
        for token in self.inpH.pre_emb:
            parts = token.split(':')
            if len(parts) < 3:
                continue
            if parts[1] == FOLLOWUP_ATTRIB_NAME and is_number(parts[2]):
                self.followupNodes.append(token)

    def __init__(self, embfile):
        """Load embeddings from ``embfile``, build the model, load its saved
        weights from SAVED_MODEL_FILE and index the follow-up value nodes."""
        self.inpH = init_embedding(embfile)

        self.model = buildModel(NUM_CLASSES, self.inpH.vocab_size,
                                self.inpH.embedding_matrix.shape[1],
                                MAXLEN, self.inpH.embedding_matrix)
        self.model.summary()
        self.model.load_weights(SAVED_MODEL_FILE)

        self.constructFollowupNodes()

    def closestFollowupNode(self, value):  # value: float
        """Return the follow-up value node whose value is nearest to ``value``.

        Ties go to the node listed first in ``self.followupNodes``; a NaN
        distance counts as infinitely far.

        Raises:
            ValueError: if ``self.followupNodes`` is empty.
        """
        if not self.followupNodes:
            raise ValueError("no follow-up value nodes to match against")

        def distance(node):
            diff = abs(value - float(node.split(':')[2]))
            return float('inf') if np.isnan(diff) else diff

        # No arbitrary cap on the distance: the true nearest node is returned
        # however far away the requested value is.
        return min(self.followupNodes, key=distance)

    def modifyNodeInstanceVec(self, node, value):  # node: <Type>:<AttribId>:<ArmId>
        """Replace ``pre_emb[node]`` with a pseudo vector for follow-up ``value``.

        The new float vector is built from
          * the first NODEVEC_DIM entries (node-definition part) of the
            follow-up value node closest to ``value``, and
          * the next PUBMED_DIM + VAL_DIMENSIONS entries (context and value
            part) of ``node``'s current vector,
        after which its first value slot (index -VAL_DIMENSIONS) is set to
        ``value``. The dictionary entry is replaced by this new array; the
        previous array object is not mutated. Undo with
        ``revertNodeInstanceVec``.

        Raises:
            ValueError: if the vectors are shorter than the configured layout.
        """
        matchedNode = self.closestFollowupNode(value)
        total = NODEVEC_DIM + PUBMED_DIM + VAL_DIMENSIONS

        definition_part = np.asarray(self.inpH.pre_emb[matchedNode][:NODEVEC_DIM], dtype=float)
        context_part = np.asarray(self.inpH.pre_emb[node][NODEVEC_DIM:total], dtype=float)
        instvec = np.concatenate([definition_part, context_part])
        if instvec.shape[0] != total:
            raise ValueError("embedding vectors must have at least %d entries" % total)

        instvec[-VAL_DIMENSIONS] = value  # new follow-up value
        self.inpH.pre_emb[node] = instvec

    def revertNodeInstanceVec(self, node, subvec, value):
        """Undo ``modifyNodeInstanceVec`` for ``node`` in place.

        Writes ``subvec`` (the node's original leading entries, e.g. from
        ``getNodeVec``) back over the start of ``pre_emb[node]`` and restores
        the original follow-up ``value`` at index -VAL_DIMENSIONS. The context
        part needs no restoring because ``modifyNodeInstanceVec`` copies it
        unchanged. The entry stays the float array created by the modification,
        not the original object.
        """
        vec = self.inpH.pre_emb[node]
        for i, x in enumerate(subvec):
            vec[i] = x
        vec[-VAL_DIMENSIONS] = value  # reverting back

In [21]:
'''
A token has the form <TYPE>:<ATTRIBID>:<DOC-ARM>. The DOC-ARM part of the
token is used to get the exact numerical value of an attribute (if the
attribute is numeric, which is the case for the follow-up attribute).

To do this, we look up the vector of the given token in the embedding file.
The first of the last 5 numbers of the vector holds the value.
E.g. 
O:4087191:Schnoll_2019.pdf_1 0.19425800442695618... <52.0> 0.0 0.0 0.0 0.0
From the line above we know that the follow-up duration is 52 weeks.
'''

def geTrueFollowupValue(token, inpH):
    """Return the value stored in ``token``'s embedding vector.

    The value is the first of the trailing VAL_DIMENSIONS entries of the
    vector found under ``token`` in ``inpH.pre_emb``.
    """
    embeddings = inpH.pre_emb
    tokenVector = embeddings[token]
    return tokenVector[-VAL_DIMENSIONS]

In [22]:
def getInstanceVecInfo(followupAnalyzer, instance):
    """Split ``instance`` around its first follow-up node.

    Returns:
        ``(node, valueInData, instance_without_followup)``: the first
        follow-up token, the value stored in its embedding vector, and a new
        list of the remaining tokens with that single occurrence removed.
        ``(None, 0, None)`` when ``instance`` has no follow-up node.
    """
    pos = followupAttribOccurrenceIndex(instance)
    if pos < 0:
        return None, 0, None

    followupNode = instance[pos]
    trueValue = geTrueFollowupValue(followupNode, followupAnalyzer.inpH)
    remaining = [tok for k, tok in enumerate(instance) if k != pos]
    return followupNode, trueValue, remaining

def getNodeVec(vec):
    """Copy the first NODEVEC_DIM entries of ``vec`` into a new list."""
    return [vec[k] for k in range(NODEVEC_DIM)]

In [23]:
def predictionOnInstance(followupAnalyzer, avpsequence, followupValues=range(8, 52, 4)):
    """Predict one arm's outcome under its true and several pseudo follow-up times.

    The result list (printed and returned) holds, in order:
      0. the prediction for ``avpsequence`` as-is (true follow-up value),
      1. the prediction with the follow-up node removed,
      2.. one prediction per ``t`` in ``followupValues``, made with the
          follow-up node's vector temporarily rewritten for value ``t``.

    Args:
        followupAnalyzer: a ``FollowupAnalyzer`` (uses ``inpH``, ``model``
            and its modify/revert helpers).
        avpsequence: list of ``<Type>:<AttribId>:<Arm>`` tokens.
        followupValues: pseudo follow-up values to probe; the default
            (8, 12, ..., 48) is the range this notebook always used.

    Returns:
        The list of predictions, or None if ``avpsequence`` has no follow-up node.
    """
    node, valueInData, instance_without_followup = getInstanceVecInfo(followupAnalyzer, avpsequence)
    if node is None:
        return None

    inpH = followupAnalyzer.inpH
    model = followupAnalyzer.model
    # Snapshot the node-definition part so every pseudo value is reverted to
    # the original vector.
    ndvec = getNodeVec(inpH.pre_emb[node])

    results = [
        predict_outcome(inpH, model, avpsequence)[0],                # true follow-up
        predict_outcome(inpH, model, instance_without_followup)[0],  # no follow-up
    ]

    for t in followupValues:
        # Changes the vector the prediction should use for this node.
        followupAnalyzer.modifyNodeInstanceVec(node, t)
        try:
            predicted_val = predict_outcome(inpH, model, avpsequence)
        finally:
            # Always restore, even if prediction fails, so later calls (and
            # later instances) see the true vector.
            followupAnalyzer.revertNodeInstanceVec(node, ndvec, float(valueInData))
        results.append(predicted_val[0])

    print(results)
    return results

In [24]:
def isIntervention(token):
    """Return True when the token's type prefix (text before the first ':') is 'I'."""
    typePrefix, _, _ = token.partition(':')
    return typePrefix == 'I'
    
#Function to filter out only the interventions from an avp sequence
def filterInterventions(avpseq):
    interventionseq = []
    for x in avpseq:
        if isFollowupNode(x)==True or isIntervention(x)==True:
            interventionseq.append(x)
    return ' '.join(interventionseq)            

In [25]:
# Load the token sequences (first TSV column) of a data file.
def getTsvData(filepath):
    """Read the first tab-separated column of ``filepath`` as token lists.

    Each non-blank line contributes ``line.strip().split("\\t")[0].split(' ')``;
    any further columns (e.g. labels) are ignored. Blank lines are skipped,
    since they would otherwise yield a ``['']`` instance with an empty token.
    """
    print("Loading data from " + filepath)
    x = []

    # Context manager so the file handle is closed deterministically.
    with open(filepath) as fin:
        for line in fin:
            stripped = line.strip()
            if not stripped:
                continue
            x.append(stripped.split("\t")[0].split(' '))

    return x

In [26]:
# Example of one arm's token sequence (space-separated <Type>:<AttribId>:<value/arm> tokens):
#   'C:5579689:18 I:3675717:1 C:5579088:35 I:3673272:1 O:4087191:52.0'
# Load the embeddings, build the model, load its saved weights and index the
# follow-up value nodes (progress log and model summary are printed below).
followupAnalyzer = FollowupAnalyzer(EMB_FILE)

converting words to ids...
Collecting node names...
Collected node names...
Converting words to ids...
Finished converting words to ids...
vocab size = 38892
Loading W2V data...
skipping word M:6080688:When_Smokers_QuitThe_Health_Bene
skipping word M:6080688:lea
skipping word SO:6080714:quali
skipping word R:6080719:Two_hundred_and_
loaded word2vec for 38891 nodes
5 words out of 38892 not found
DEBUG: shape of embedding: (38892, 333)
DEBUG: include_wordvecs = False
Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
embedding_2 (Embedding)      (None, 50, 333)           12951036  
_________________________________________________________________
bidirectional_2 (Bidirection (None, 128)               203776    
_________________________________________________________________
output_vals (Dense)          (None, 1)                 129       
Total params: 13,154,941
Trainable params: 20

In [27]:
# Only the first TSV column (the token sequence) is read; the labels are unused.
text_instances = getTsvData(DATAFILE)
# Work only with the intervention (and follow-up) tokens of each instance.
text_instances_interventions_only = [
    filterInterventions(avp_tokens) for avp_tokens in text_instances
]

Loading data from ../../core/prediction/sentences/train.tsv


In [28]:
# Collect the intervention-only instances that have more than one token and
# contain a follow-up node, then print the predictions for each of them.
followup_datainstances = []
for avpseq in text_instances_interventions_only:
    tokens = avpseq.split(' ')  # string to list of tokens
    if len(tokens) <= 1:
        continue
    node, _, _ = getInstanceVecInfo(followupAnalyzer, tokens)
    if node is not None:
        followup_datainstances.append(tokens)

for avpseq in followup_datainstances:
    predictionOnInstance(followupAnalyzer, avpseq)

[20.325869, 19.998446, 20.325869, 20.325869, 20.325869, 20.325869, 20.325869, 20.325869, 20.325869, 20.325869, 20.325869, 20.325869, 20.325869]
[17.644108, 15.509278, 17.644108, 17.644108, 17.644108, 17.644108, 17.644108, 17.644108, 17.644108, 17.644108, 17.644108, 17.644108, 17.644108]
[15.478878, 12.91596, 15.478878, 15.478878, 15.478878, 15.478878, 15.478878, 15.478878, 15.478878, 15.478878, 15.478878, 15.478878, 15.478878]
[9.477763, 9.085609, 9.477763, 9.477763, 9.477763, 9.477763, 9.477763, 9.477763, 9.477763, 9.477763, 9.477763, 9.477763, 9.477763]
[17.100475, 17.075266, 17.100475, 17.100475, 17.100475, 17.100475, 17.100475, 17.100475, 17.100475, 17.100475, 17.100475, 17.100475, 17.100475]
[9.630969, 9.146614, 9.630969, 9.630969, 9.630969, 9.630969, 9.630969, 9.630969, 9.630969, 9.630969, 9.630969, 9.630969, 9.630969]
[16.875217, 13.29198, 16.875217, 16.875217, 16.875217, 16.875217, 16.875217, 16.875217, 16.875217, 16.875217, 16.875217, 16.875217, 16.875217]
[15.541933, 15.90976

[21.692259, 17.217728, 21.692259, 21.692259, 21.692259, 21.692259, 21.692259, 21.692259, 21.692259, 21.692259, 21.692259, 21.692259, 21.692259]
[25.246351, 22.430994, 25.246351, 25.246351, 25.246351, 25.246351, 25.246351, 25.246351, 25.246351, 25.246351, 25.246351, 25.246351, 25.246351]
[18.36244, 18.019016, 18.36244, 18.36244, 18.36244, 18.36244, 18.36244, 18.36244, 18.36244, 18.36244, 18.36244, 18.36244, 18.36244]
[20.114191, 18.322653, 20.114191, 20.114191, 20.114191, 20.114191, 20.114191, 20.114191, 20.114191, 20.114191, 20.114191, 20.114191, 20.114191]
[23.768164, 19.290836, 23.768164, 23.768164, 23.768164, 23.768164, 23.768164, 23.768164, 23.768164, 23.768164, 23.768164, 23.768164, 23.768164]
[21.847448, 19.069971, 21.847448, 21.847448, 21.847448, 21.847448, 21.847448, 21.847448, 21.847448, 21.847448, 21.847448, 21.847448, 21.847448]
[20.791029, 21.145384, 20.791029, 20.791029, 20.791029, 20.791029, 20.791029, 20.791029, 20.791029, 20.791029, 20.791029, 20.791029, 20.791029]
[22.

[20.148115, 19.942621, 20.148115, 20.148115, 20.148115, 20.148115, 20.148115, 20.148115, 20.148115, 20.148115, 20.148115, 20.148115, 20.148115]
[25.839186, 18.499971, 25.839186, 25.839186, 25.839186, 25.839186, 25.839186, 25.839186, 25.839186, 25.839186, 25.839186, 25.839186, 25.839186]
[25.06477, 18.98963, 25.06477, 25.06477, 25.06477, 25.06477, 25.06477, 25.06477, 25.06477, 25.06477, 25.06477, 25.06477, 25.06477]
[12.583752, 12.484522, 12.583752, 12.583752, 12.583752, 12.583752, 12.583752, 12.583752, 12.583752, 12.583752, 12.583752, 12.583752, 12.583752]
[18.204985, 15.830563, 18.204985, 18.204985, 18.204985, 18.204985, 18.204985, 18.204985, 18.204985, 18.204985, 18.204985, 18.204985, 18.204985]
[14.437294, 13.298339, 14.437294, 14.437294, 14.437294, 14.437294, 14.437294, 14.437294, 14.437294, 14.437294, 14.437294, 14.437294, 14.437294]
[14.514162, 13.272689, 14.514162, 14.514162, 14.514162, 14.514162, 14.514162, 14.514162, 14.514162, 14.514162, 14.514162, 14.514162, 14.514162]
[8.93

[13.082283, 11.890125, 13.082283, 13.082283, 13.082283, 13.082283, 13.082283, 13.082283, 13.082283, 13.082283, 13.082283, 13.082283, 13.082283]
[14.5356655, 13.120851, 14.5356655, 14.5356655, 14.5356655, 14.5356655, 14.5356655, 14.5356655, 14.5356655, 14.5356655, 14.5356655, 14.5356655, 14.5356655]
[16.441948, 13.861741, 16.441948, 16.441948, 16.441948, 16.441948, 16.441948, 16.441948, 16.441948, 16.441948, 16.441948, 16.441948, 16.441948]
[13.60979, 11.6245, 13.60979, 13.60979, 13.60979, 13.60979, 13.60979, 13.60979, 13.60979, 13.60979, 13.60979, 13.60979, 13.60979]
[15.63224, 14.375158, 15.63224, 15.63224, 15.63224, 15.63224, 15.63224, 15.63224, 15.63224, 15.63224, 15.63224, 15.63224, 15.63224]
[12.0660095, 11.262869, 12.0660095, 12.0660095, 12.0660095, 12.0660095, 12.0660095, 12.0660095, 12.0660095, 12.0660095, 12.0660095, 12.0660095, 12.0660095]
[25.639912, 20.419079, 25.639912, 25.639912, 25.639912, 25.639912, 25.639912, 25.639912, 25.639912, 25.639912, 25.639912, 25.639912, 25.63

[20.441381, 18.67848, 20.441381, 20.441381, 20.441381, 20.441381, 20.441381, 20.441381, 20.441381, 20.441381, 20.441381, 20.441381, 20.441381]
[18.28708, 17.60163, 18.28708, 18.28708, 18.28708, 18.28708, 18.28708, 18.28708, 18.28708, 18.28708, 18.28708, 18.28708, 18.28708]
[9.99721, 9.171494, 9.99721, 9.99721, 9.99721, 9.99721, 9.99721, 9.99721, 9.99721, 9.99721, 9.99721, 9.99721, 9.99721]
[14.780291, 14.407162, 14.780291, 14.780291, 14.780291, 14.780291, 14.780291, 14.780291, 14.780291, 14.780291, 14.780291, 14.780291, 14.780291]
[17.149988, 15.552059, 17.149988, 17.149988, 17.149988, 17.149988, 17.149988, 17.149988, 17.149988, 17.149988, 17.149988, 17.149988, 17.149988]
[13.33377, 13.674696, 13.33377, 13.33377, 13.33377, 13.33377, 13.33377, 13.33377, 13.33377, 13.33377, 13.33377, 13.33377, 13.33377]
[23.248493, 19.900866, 23.248493, 23.248493, 23.248493, 23.248493, 23.248493, 23.248493, 23.248493, 23.248493, 23.248493, 23.248493, 23.248493]
[26.541574, 24.648394, 26.541574, 26.541574

[22.100718, 21.271267, 22.100718, 22.100718, 22.100718, 22.100718, 22.100718, 22.100718, 22.100718, 22.100718, 22.100718, 22.100718, 22.100718]
[17.618975, 16.370878, 17.618975, 17.618975, 17.618975, 17.618975, 17.618975, 17.618975, 17.618975, 17.618975, 17.618975, 17.618975, 17.618975]
[14.899834, 12.525003, 14.899834, 14.899834, 14.899834, 14.899834, 14.899834, 14.899834, 14.899834, 14.899834, 14.899834, 14.899834, 14.899834]
[16.171572, 12.715633, 16.171572, 16.171572, 16.171572, 16.171572, 16.171572, 16.171572, 16.171572, 16.171572, 16.171572, 16.171572, 16.171572]
[19.384111, 18.844332, 19.384111, 19.384111, 19.384111, 19.384111, 19.384111, 19.384111, 19.384111, 19.384111, 19.384111, 19.384111, 19.384111]
[29.62932, 23.366247, 29.62932, 29.62932, 29.62932, 29.62932, 29.62932, 29.62932, 29.62932, 29.62932, 29.62932, 29.62932, 29.62932]
[11.608441, 10.090066, 11.608441, 11.608441, 11.608441, 11.608441, 11.608441, 11.608441, 11.608441, 11.608441, 11.608441, 11.608441, 11.608441]
[21.

[13.148722, 11.041665, 13.148722, 13.148722, 13.148722, 13.148722, 13.148722, 13.148722, 13.148722, 13.148722, 13.148722, 13.148722, 13.148722]
[22.482977, 18.623972, 22.482977, 22.482977, 22.482977, 22.482977, 22.482977, 22.482977, 22.482977, 22.482977, 22.482977, 22.482977, 22.482977]
[16.477505, 13.807027, 16.477505, 16.477505, 16.477505, 16.477505, 16.477505, 16.477505, 16.477505, 16.477505, 16.477505, 16.477505, 16.477505]
[13.407625, 14.756586, 13.407625, 13.407625, 13.407625, 13.407625, 13.407625, 13.407625, 13.407625, 13.407625, 13.407625, 13.407625, 13.407625]
[15.153386, 13.757289, 15.153386, 15.153386, 15.153386, 15.153386, 15.153386, 15.153386, 15.153386, 15.153386, 15.153386, 15.153386, 15.153386]
[12.354502, 10.737064, 12.354502, 12.354502, 12.354502, 12.354502, 12.354502, 12.354502, 12.354502, 12.354502, 12.354502, 12.354502, 12.354502]
[17.42457, 16.228031, 17.42457, 17.42457, 17.42457, 17.42457, 17.42457, 17.42457, 17.42457, 17.42457, 17.42457, 17.42457, 17.42457]
[16.

[21.658113, 21.209726, 21.658113, 21.658113, 21.658113, 21.658113, 21.658113, 21.658113, 21.658113, 21.658113, 21.658113, 21.658113, 21.658113]
