In [1]:
import tensorflow as tf
import tensorflow_text 
import numpy as np
from tqdm import tqdm
import pandas as pd
from tensorflow_io.bigquery import BigQueryClient
# BigQuery Storage API client; read_session() below uses it to open read sessions on the source table.
client = BigQueryClient()

## Defining the dataset

In [2]:
# GCP project used as the parent of the BigQuery read sessions.
GCP_PROJECT_ID = 'for-antoine'
# A copy of the data is saved in the user project, so the dataset lives in that same project.
DATASET_GCP_PROJECT_ID = GCP_PROJECT_ID
DATASET_ID = 'MESH_CLASSIFICATION'
TRAIN_TABLE_ID = 'jama_no_keywords'

# Columns read from the table; each one is read as a string tensor.
FEATURES = ['id', 'title', 'abstract']
DTYPES = [tf.string for _ in FEATURES]

In [3]:
# SQL equivalent of the table read through the BigQuery Storage API below, kept for
# reference / ad-hoc inspection (it is not executed anywhere in this notebook).
# Built from the config constants so it cannot drift from the table actually read.
query = f"""
   SELECT *
   FROM `{DATASET_GCP_PROJECT_ID}.{DATASET_ID}.{TRAIN_TABLE_ID}`
"""

In [4]:
# Number of rows per inference batch.
BATCH_SIZE = 64


def read_session(TABLE_ID):
    """Open a BigQuery Storage read session on table `TABLE_ID` of DATASET_ID.

    The session is created under the parent `projects/<GCP_PROJECT_ID>`, selects
    the FEATURES columns with types DTYPES, and requests two streams so that
    `parallel_read_rows()` can read them concurrently.
    """
    parent = "projects/" + GCP_PROJECT_ID
    return client.read_session(
        parent,
        DATASET_GCP_PROJECT_ID,
        TABLE_ID,
        DATASET_ID,
        FEATURES,
        DTYPES,
        requested_streams=2,
    )

def extract_question(input_dict):
    """Build an (id, question) pair from one BigQuery row for the organism model.

    The question is the 'bioasq organism: ' task prefix, the journal tag, the
    title as read, and the lower-cased abstract.
    """
    row = dict(input_dict)
    row_id = row['id']
    abstract_lower = tf.strings.lower(row["abstract"])
    question = 'bioasq organism: ' + 'journal: jama title: ' + row["title"] + ' abstract: ' + abstract_lower
    return (row_id, question)


# Batched (id, question) dataset over the training table, built with extract_question.
# The inference loop rebuilds `raw_train_data` per model with that model's own prefix.
raw_train_data = (
    read_session(TRAIN_TABLE_ID)
    .parallel_read_rows()
    .map(extract_question)
    .batch(BATCH_SIZE)
)


## Defining the models

In [12]:
def model_export(batch_size, model_name, model_size, model_version, model_number, prefix, candidates,
                 expected_batch_size=None):
    """Bundle the settings of one exported model into a config dict.

    Parameters:
        batch_size: batch size the model config assumes.
        model_name, model_size, model_version, model_number: path components of
            the exported SavedModel.
        prefix: task prefix prepended to every question for this model.
        candidates: candidate labels, aligned position-by-position with the
            tokens of the model output.
        expected_batch_size: batch size to validate against; defaults to the
            notebook-wide BATCH_SIZE (looked up at call time).

    Returns:
        dict with one key per positional parameter above.

    Raises:
        ValueError: if batch_size differs from the expected batch size. (This was
            previously a printed message, after which the config was still returned.)
    """
    expected = BATCH_SIZE if expected_batch_size is None else expected_batch_size
    if batch_size != expected:
        raise ValueError(
            f'batch size {batch_size} does not match the expected batch size {expected}'
        )
    return {
        'batch_size': batch_size,
        'model_name': model_name,
        'model_size': model_size,
        'model_version': model_version,
        'model_number': model_number,
        'prefix': prefix,
        'candidates': candidates,
    }



# One entry per exported T5 checkpoint to run; keys mirror model_export's parameters.
# Candidate order matters: the i-th token of a model's output refers to candidates[i].
MODEL_SPECS = [
    dict(
        model_name='BIOASQ-organism',
        model_size='base',
        model_version='general',
        model_number='1592512738',
        prefix='bioasq organism: ',
        candidates=['Antibodies', 'Monoclonal', 'Receptors', 'Recombinant Proteins', 'Amino Acid Sequence',
                    'Cells', 'DNA', 'Viral', 'Genes', 'Antineoplastic Agents', 'Binding Sites', 'Antigens',
                    'Tumor Cells', 'Neoplasms', 'Neurons', 'Liver', 'Brain', 'Escherichia coli', 'Tumor',
                    'Muscle', 'Kidney', 'Chromosomes'],
    ),
    dict(
        model_name='BIOASQ-techniques',
        model_size='base',
        model_version='general',
        model_number='1592657921',
        prefix='bioasq techniques: ',
        candidates=['Prognosis', 'Retrospective Studies', 'Treatment Outcome', 'Disease Models', 'Cultured',
                    'Cloning', 'Electrophoresis', 'Molecular Sequence Data', 'Polymerase Chain Reaction',
                    'Severity of Illness Index', 'Injections'],
    ),
    dict(
        model_name='BIOASQ-categories',
        model_size='base',
        model_version='general',
        model_number='1592659655',
        prefix='bioasq general: ',
        candidates=['Animals', 'Male', 'Rabbits', 'Adult', 'Humans', 'Dogs', 'Cattle', 'Animal', 'Female',
                    'Mice', 'Human', 'Infant', 'Adolescent', 'Middle Aged', 'Young Adult', 'Age Factors',
                    'Aged', '80 and over', 'Rats', 'Child', 'Preschool', 'Sprague-Dawley', 'Swine', 'Newborn'],
    ),
]

# All checkpoints were exported for batches of 64; model_export validates this.
models = [model_export(batch_size=64, **spec) for spec in MODEL_SPECS]

## Running inference

In [13]:
def prediction(pairs, predictor=None):
    """Run a model on a batch of input strings and decode its outputs.

    Parameters:
        pairs: batch of model inputs (e.g. the question bytes of one dataset
            batch), passed unchanged to the predictor.
        predictor: callable returning an iterable of UTF-8 encoded bytes, one per
            input. Defaults to the notebook-global ``predict_fn`` that the
            inference loop rebinds for each model, looked up at call time.

    Returns:
        list of str: the decoded model outputs, in the predictor's output order.
    """
    if predictor is None:
        predictor = predict_fn  # global, rebound per model in the inference loop
    return [raw.decode('utf-8') for raw in predictor(pairs)]

def generate_keywords(predictions, candidates):
    """Map each decoded model output to the comma-separated candidates it selects.

    Each prediction is split on single spaces. The i-th token is the verdict for
    candidates[i], and a token containing 'true' selects that candidate.

    Parameters:
        predictions: iterable of decoded model outputs (str).
        candidates: candidate labels, aligned with the output tokens.

    Returns:
        list of str, one per prediction: the selected candidates joined by ', '
        (an empty string when none is selected).
    """
    results = []
    for text in predictions:
        # zip stops at the shorter sequence, so tokens generated beyond the
        # candidate list are ignored instead of raising IndexError.
        keywords = [
            candidate
            for candidate, token in zip(candidates, text.split(' '))
            if 'true' in token
        ]
        results.append(', '.join(keywords))
    return results

        
        

In [None]:
# Rows to score per model. 128 rows (two batches) is a quick smoke run. The
# full-run value used previously was 34307 (presumably the table's row count --
# verify). Set to None to score every row.
nb_rows = 128


def _make_question_fn(prefix):
    """Return a tf.data map function that builds (id, question) pairs using `prefix`.

    Binding the prefix through this factory, instead of redefining
    `extract_question` inside the loop and closing over the loop variable, makes
    it explicit which task prefix each model's dataset uses.
    """
    def extract_question(input_dict):
        features = dict(input_dict)
        question = prefix + 'journal: jama title: ' + features["title"] + ' abstract: ' + tf.strings.lower(features["abstract"])
        return (features['id'], question)
    return extract_question


def load_predict_fn(model_path):
    """Load the SavedModel at `model_path` (tag "serve") and return a batch predictor.

    The returned callable feeds a batch of input strings to the 'serving_default'
    signature and returns its 'outputs' as a numpy array. It closes over the
    loaded object itself, not just the signature, so that object stays referenced
    while the predictor is in use.
    """
    imported = tf.saved_model.load(model_path, ["serve"])
    return lambda x: imported.signatures['serving_default'](tf.constant(x))['outputs'].numpy()


for model in models:
    model_name = model['model_name']
    saved_model_path = (
        f"gs://antoine-vs-t5/{model_name}/{model['model_size']}/"
        f"{model['model_version']}/export/{model['model_number']}"
    )

    raw_train_data = (
        read_session(TRAIN_TABLE_ID)
        .parallel_read_rows()
        .map(_make_question_fn(model['prefix']))
        .batch(BATCH_SIZE)
    )
    if nb_rows is not None:
        # Round up so that a trailing partial batch is still scored.
        raw_train_data = raw_train_data.take(-(-nb_rows // BATCH_SIZE))

    # prediction() reads the global `predict_fn`, so rebind it for this model.
    predict_fn = load_predict_fn(saved_model_path)

    # Accumulate in lists and build the frame once. Repeated np.concatenate onto an
    # empty float array is quadratic and relies on float/str dtype promotion.
    ids, predictions = [], []
    for batch_ids, batch_questions in tqdm(raw_train_data, desc=model_name):
        ids.extend(raw_id.decode('utf-8') for raw_id in batch_ids.numpy())
        predictions.extend(generate_keywords(prediction(batch_questions.numpy()), model['candidates']))

    df = pd.DataFrame({'ids': ids, 'predictions': predictions})
    df.to_csv(f'predictions-{TRAIN_TABLE_ID}-{model_name}.csv', index=False)


2it [00:21, 10.71s/it]


In [26]:
# `df` is left over from the last iteration of the inference loop, i.e. the last model in `models`.
df.sort_values('ids')

Unnamed: 0,ids,predictions
148,00511dc5-c278-473f-985b-f75b17299057,Humans
294,00b4b3cc-492e-4e32-b33d-b1455d184197,Humans
297,00e3a07d-ac64-4b53-b5c1-9483c8d5762a,"Humans, Infant, Newborn"
63,010833af-d2f0-4ce0-a758-0fbe1c491ec5,"Male, Adult, Humans, Child"
239,026df008-8110-4375-b701-82588c84a9d9,Humans
...,...,...
305,f812729f-c164-40a4-836e-6493f819b974,"Humans, Aged"
19,f97dcbf4-0cb1-40b6-8730-32aecfe2d844,Humans
125,fcb48f2f-815a-4e41-bc93-aaa2d8b85be9,"Animals, Humans"
231,fd0c451e-5e31-4bee-90ad-fcc6139037b7,"Humans, Infant"
