
# Vector Embedding & Search/Retrieval

The first half of this notebook creates the files used to load embedding data into Snowflake.

The second half tests semantic retrieval once the embedding data has been loaded.

Both halves could be run from within Snowflake if we were in a non-Government region, but creating large files in Snowflake is EXPENSIVE, since each encode call takes about 1/10 of a second.
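To put that per-call figure in perspective, here is a back-of-the-envelope estimate. It assumes one `encode` call per trial, about 144,000 trials (the count given below) and the roughly 0.1 s per call quoted above. Real timings depend on the hardware and on whether calls are batched.

```python
# Back-of-the-envelope cost of encoding every trial with one encode() call each.
# Assumed inputs: ~144,000 trials and ~0.1 s per call (figures from this notebook).
NUM_TRIALS = 144_000
SECONDS_PER_CALL = 0.1

total_seconds = NUM_TRIALS * SECONDS_PER_CALL
total_hours = total_seconds / 3600

print(f"{total_seconds:,.0f} s, about {total_hours:.1f} hours")
```

That works out to roughly four hours of encoding per full pass, which is why the encoding is done in Colab rather than on a billed Snowflake warehouse.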

# Initialization

The following cells must be run at the start of every session, whether creating embedding files or testing semantic retrieval.

In [None]:
!pip install snowflake-connector-python
!pip install -U sentence-transformers
import time
import numpy
import pandas as pd
import snowflake.connector
from sentence_transformers import SentenceTransformer

# all-MiniLM-L6-v2 maps each text to a 384-dimensional vector
model = SentenceTransformer('all-MiniLM-L6-v2')

In [None]:
from google.colab import drive
drive.mount('/drive')

Mounted at /drive


# Create Embedding Data File(s) for loading into Snowflake

There are currently about 144K clinical trials that are either recruiting or not yet recruiting.
These have been placed in the C2_CTRIAL.CLINICAL_TRIALS.CT_SEARCH data set.

Snowflake has a 16 MB size limit when loading JSON files, which allows only about 3K records per batch.
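If the ~3K figure proves too optimistic for some batches (for example, unusually long records), one option is to split by serialized size rather than by row count. Below is a minimal sketch. It assumes the 16 MB limit applies to the whole records-oriented JSON array and is measured against the uncompressed, compactly serialized text, which is a conservative assumption. `chunk_records` and the `NCT_ID` field are hypothetical names used only for illustration.

```python
import json
import random

# Limit quoted above; treated here (assumption) as a cap on the uncompressed JSON text.
MAX_BYTES = 16 * 1024 * 1024
COMPACT = (",", ":")  # compact separators, no spaces


def chunk_records(records, max_bytes=MAX_BYTES):
    """Hypothetical helper: greedily group records so that each group,
    serialized as one compact JSON array, stays within max_bytes.
    A single record larger than max_bytes still ends up alone in its own group."""
    chunks, current, current_size = [], [], 2  # 2 bytes for the enclosing "[]"
    for rec in records:
        rec_size = len(json.dumps(rec, separators=COMPACT).encode("utf-8"))
        sep = 1 if current else 0  # one comma between array elements
        if current and current_size + sep + rec_size > max_bytes:
            chunks.append(current)
            current, current_size, sep = [], 2, 0
        current.append(rec)
        current_size += sep + rec_size
    if current:
        chunks.append(current)
    return chunks


# Demo on synthetic 384-dimensional records with a small limit so several chunks form
random.seed(0)
demo = [{"NCT_ID": f"NCT{i:08d}", "EMBED": [random.uniform(-1, 1) for _ in range(384)]}
        for i in range(50)]
chunks = chunk_records(demo, max_bytes=100_000)
print(len(chunks), [len(c) for c in chunks])
```

Each chunk could then be written to its own file, so that file sizes, not a fixed row count, determine the number of files.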

In [None]:
rowCount = 145000
batchSize = 3000
filename = 'embed'
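The loop below walks the table in fixed windows of `batchSize` rows. As a sketch of the window arithmetic alone, here is a hypothetical `batch_ranges` helper that yields the inclusive `(startRow, endRow)` pairs. Unlike the loop, it clamps the last `endRow` to `rowCount`, which makes no difference to the query because rows beyond the table's end simply are not returned.

```python
def batch_ranges(row_count, batch_size, start=1):
    """Hypothetical helper: yield inclusive (startRow, endRow) pairs covering start..row_count."""
    for lo in range(start, row_count + 1, batch_size):
        yield lo, min(lo + batch_size - 1, row_count)


ranges = list(batch_ranges(145000, 3000))
print(len(ranges), ranges[0], ranges[-1])  # 49 (1, 3000) (144001, 145000)
```

With `rowCount = 145000` and `batchSize = 3000`, that means 49 output files, named `embed1.json`, `embed3001.json`, and so on up to `embed144001.json`.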

In [None]:
# modify the query and the marked variable assignment block as appropriate
import os
import getpass

# Read the password from the environment (or prompt for it) instead of hard-coding it
snowPassword = os.environ.get('SNOWFLAKE_PASSWORD') or getpass.getpass('Snowflake password: ')
# Open one connection for all batches and close it when done
cnx = snowflake.connector.connect(user='waserma', password=snowPassword,
        account='csi_eval.us-east-1-gov.aws',
        warehouse='WAREHOUSE1')
textCols = ['BRIEF_TITLE', 'OFFICIAL_TITLE', 'DESCRIPTION', 'DOWNCASE_NAMES', 'DOWNCASE_MESH_TERMS']
startRow = 1
try:
    while startRow <= rowCount:
        endRow = startRow + batchSize - 1
        fname = '/drive/My Drive/' + filename + str(startRow) + '.json'
        print(time.strftime("%H:%M:%S", time.localtime()) + '  ' + fname)
        curs = cnx.cursor()
        # Bind the row bounds as parameters instead of concatenating them into the SQL
        query = ('SELECT * FROM C2_CTRIAL.CLINICAL_TRIALS.CT_SEARCH '
                 'WHERE ROWNUM >= %s AND ROWNUM <= %s ORDER BY ROWNUM')
        snowData = curs.execute(query, (startRow, endRow)).fetchall()
        column_names = [i[0] for i in curs.description]
        curs.close()
        tempDF = pd.DataFrame(snowData, columns=column_names)
        if tempDF.empty:
            break  # no rows left in this range
        # BEGIN MODIFY
        # fillna('') keeps NULLs from becoming the literal text 'None'
        tempDF[textCols] = tempDF[textCols].fillna('').astype('str')
        # Join fields with a space so words at field boundaries do not run together
        texts = tempDF[textCols].apply(' '.join, axis=1).tolist()
        # Encode the whole batch in one call (typically much faster than one encode() per row)
        tempDF['EMBED'] = model.encode(texts).tolist()
        tempDF.drop(['ROWNUM'] + textCols, axis=1, inplace=True)
        # END MODIFY
        tempDF.to_json(fname, orient="records")
        startRow = startRow + batchSize
finally:
    cnx.close()
print("done!")


19:09:45  /drive/My Drive/embed1.json
done!
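Before staging the files in Snowflake, a quick sanity check can catch truncated writes or NULL vectors. Below is a minimal sketch using only the standard library. It assumes each file is a records-oriented JSON array (as written by `to_json(orient="records")`) with an `EMBED` field of 384 numbers, the output size of all-MiniLM-L6-v2. `check_embed_file` is a hypothetical helper, and the demo file is synthetic.

```python
import json
import math
import os
import tempfile

EMBED_DIM = 384  # all-MiniLM-L6-v2 produces 384-dimensional vectors


def check_embed_file(path, dim=EMBED_DIM):
    """Hypothetical helper: return the record count of a records-oriented JSON
    file after checking that every EMBED value is a list of `dim` finite numbers."""
    with open(path, encoding="utf-8") as f:
        records = json.load(f)
    if not isinstance(records, list):
        raise ValueError(f"{path}: expected a JSON array of records")
    for i, rec in enumerate(records):
        vec = rec.get("EMBED")
        if not isinstance(vec, list) or len(vec) != dim:
            raise ValueError(f"{path}: record {i} has a missing or wrong-sized EMBED")
        if not all(isinstance(x, (int, float)) and math.isfinite(x) for x in vec):
            raise ValueError(f"{path}: record {i} has a non-numeric or non-finite value")
    return len(records)


# Demo on a small synthetic file shaped like the notebook's output
with tempfile.TemporaryDirectory() as tmp:
    demo_path = os.path.join(tmp, "embed_demo.json")
    with open(demo_path, "w", encoding="utf-8") as f:
        json.dump([{"EMBED": [0.0] * EMBED_DIM}, {"EMBED": [0.5] * EMBED_DIM}], f)
    n_records = check_embed_file(demo_path)
print(n_records)  # 2
```

On the real output, something like `check_embed_file('/drive/My Drive/embed1.json')` would be run for each file. A NaN written by pandas comes back as `null`, which the numeric check rejects.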


# Test Semantic Search/Retrieval

Waiting for the Snowflake side to be completed . . . For now, the cell below encodes a search string and fetches a sample of rows.

In [None]:
import json
import os
import getpass

searchStr = input('What is your medical malfunction?  ')
searchVec = model.encode(searchStr)
# Format the vector as a bracketed, comma-separated list of its 384 values
finalStr = json.dumps(searchVec.tolist())
print(finalStr)
# print(time.strftime("%H:%M:%S", time.localtime()) + '  start search')
# print(searchVec)
# Read the password from the environment (or prompt for it) instead of hard-coding it
snowPassword = os.environ.get('SNOWFLAKE_PASSWORD') or getpass.getpass('Snowflake password: ')
cnx = snowflake.connector.connect(user='waserma', password=snowPassword,
        account='csi_eval.us-east-1-gov.aws',
        warehouse='WAREHOUSE1')
curs = cnx.cursor()
query = 'SELECT * FROM C2_CTRIAL.CLINICAL_TRIALS.CT_SEARCH WHERE ROWNUM >= 1 AND ROWNUM <= 10 ORDER BY ROWNUM'
snowData = curs.execute(query).fetchall()
column_names = [i[0] for i in curs.description]
curs.close()
cnx.close()
tempDF = pd.DataFrame(snowData, columns=column_names)
# print(time.strftime("%H:%M:%S", time.localtime()) + '  done')
tempDF.head()
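The retrieval step this section builds toward scores stored trial embeddings against `searchVec` and keeps the best matches. Below is a minimal sketch of that ranking using cosine similarity in NumPy over an in-memory matrix (for example, vectors loaded from the `embed*.json` files). It is not the Snowflake-side implementation, which is still pending. `top_k_cosine` is a hypothetical helper and the demo data is random. Normalizing explicitly keeps the sketch correct whether or not the stored vectors are already unit-length.

```python
import numpy as np


def top_k_cosine(query_vec, embed_matrix, k=5):
    """Hypothetical helper: return (indices, scores) of the k rows of
    embed_matrix most similar to query_vec by cosine similarity, highest first."""
    q = np.asarray(query_vec, dtype=np.float32)
    m = np.asarray(embed_matrix, dtype=np.float32)
    # Normalize so a dot product equals cosine similarity; guard against zero vectors
    q = q / max(np.linalg.norm(q), 1e-12)
    m = m / np.maximum(np.linalg.norm(m, axis=1, keepdims=True), 1e-12)
    scores = m @ q
    k = min(k, len(scores))
    idx = np.argsort(-scores)[:k]  # descending by score
    return idx, scores[idx]


# Demo: a slightly perturbed copy of row 42 should rank row 42 first
rng = np.random.default_rng(0)
embeds = rng.normal(size=(1000, 384)).astype(np.float32)
query = embeds[42] + 0.01 * rng.normal(size=384).astype(np.float32)
idx, scores = top_k_cosine(query, embeds, k=3)
print(idx, scores)
```

With real data, `query_vec` would be `searchVec` and `embed_matrix` a 2-D array built from the stored `EMBED` lists. The returned indices map back to the trial identifiers kept alongside them.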
