## Set Up


In [None]:
# Notebook-only setup: the `%` and `!` lines are IPython magics, not plain Python.
print("Installing dependencies...")
# Select the TensorFlow 2.x runtime.
%tensorflow_version 2.x
# Quietly install the t5 package.
!pip install -q t5

import functools
import os
import time
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

import tensorflow.compat.v1 as tf
import tensorflow_datasets as tfds

import t5
import gin


## Set Up TPU Runtime

In [None]:
# Toggles the Colab/GCS/TPU-specific setup below.
ON_CLOUD = True


if ON_CLOUD:
  print("Setting up GCS access...")
  import tensorflow_gcs_config
  from google.colab import auth
  TPU_TOPOLOGY = "v3-8"
  try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection
    TPU_ADDRESS = tpu.get_master()
    print('Running on TPU:', TPU_ADDRESS)
  except ValueError as err:
    # Raise an ordinary exception (not a bare BaseException) and keep the
    # resolver's original error attached as the cause.
    raise RuntimeError('ERROR: Not connected to a TPU runtime; please see the previous cell in this notebook for instructions!') from err
  # Set credentials for GCS reading/writing from Colab and TPU.
  auth.authenticate_user()
  tf.config.experimental_connect_to_host(TPU_ADDRESS)
  tensorflow_gcs_config.configure_gcs_from_colab_auth()

# Switch TensorFlow to its v1 behavior for the rest of the notebook.
tf.disable_v2_behavior()
# Load the gin configuration file from GCS.
gin.parse_config_file(
        'gs://t5_training/t5-data/config/pretrained_models_google_base_operative_config.gin'
)
# gin.bind_parameter("SentencePieceVocabulary.extra_ids", 100)

# DEFAULT_OUTPUT_FEATURES = {
#     "inputs": Feature(
#         vocabulary=get_default_vocabulary(), add_eos=True, required=False),
#     "targets": Feature(vocabulary=get_default_vocabulary(), add_eos=True)
# }
# Path of the SentencePiece model file; the task registrations and the
# prediction cell build their vocabulary from it.
vocab = "gs://t5_training/models/spm/t5_bio_spm_small.model"

# Improve logging.
from contextlib import contextmanager
import logging as py_logging

if ON_CLOUD:
  # Stop TensorFlow's logger from passing its records on to ancestor loggers.
  tf.get_logger().propagate = False
  # Set the root Python logger's level to INFO.
  py_logging.root.setLevel('INFO')

@contextmanager
def tf_verbosity_level(level):
  """Temporarily set TensorFlow's logging verbosity to `level`.

  The verbosity that was active on entry is restored when the `with` block
  exits, including when the block raises.

  Args:
    level: value passed to `tf.logging.set_verbosity`.
  """
  og_level = tf.logging.get_verbosity()
  tf.logging.set_verbosity(level)
  try:
    yield
  finally:
    # Without the finally, an exception in the body would leave the
    # temporary verbosity in place for the rest of the session.
    tf.logging.set_verbosity(og_level)


## Register NER Tasks

### NCBI NER Task

In [None]:
def dumping_dataset(split, shuffle_files = False):
    """Load the NCBI-disease TSV for `split` as a dataset of dicts.

    Args:
      split: 'train' reads the training file; any other value reads the
        test file.
      shuffle_files: accepted but ignored (it is deleted immediately).

    Returns:
      A tf.data dataset of {"input": ..., "target": ...} dicts, one per line,
      built from the line's two tab-separated columns.
    """
    del shuffle_files
    tsv_path = (
        'gs://scifive/finetune/NCBI-disease/train.tsv_cleaned.tsv'
        if split == 'train'
        else 'gs://scifive/finetune/NCBI-disease/test.tsv_cleaned.tsv')
    # Parse each "<input>\t<target>" line into two string columns.
    parse_line = functools.partial(
        tf.io.decode_csv, record_defaults=["", ""],
        field_delim="\t", use_quote_delim=False)
    columns = tf.data.TextLineDataset([tsv_path]).map(
        parse_line, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    # Name the two columns so later stages can index them by key.
    return columns.map(lambda *cols: dict(zip(["input", "target"], cols)))

# Sanity check: show the first five parsed rows. The loop reads the "train"
# split, so the banner says so (it previously claimed "validation").
print("A few raw training examples...")
for ex in tfds.as_numpy(dumping_dataset("train").take(5)):
  print(ex)

In [None]:
def ner_preprocessor(ds):
  """Map raw {"input", "target"} examples to {"inputs", "targets"} features.

  "inputs" is the example's "input" string prefixed with "ncbi_ner: ";
  "targets" is the example's "target" string, unchanged.
  """
  def add_task_prefix(example):
    tagged_input = tf.strings.join(["ncbi_ner: ", example["input"]])
    return {"inputs": tagged_input, "targets": example["target"]}

  return ds.map(add_task_prefix,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)

In [None]:
# Drop any earlier registration first -- presumably so this cell can be re-run.
t5.data.TaskRegistry.remove("ncbi_ner")
t5.data.TaskRegistry.add(
    "ncbi_ner",
    # Function that returns the raw tf.data.Dataset for a split.
    dataset_fn=dumping_dataset,
    # dumping_dataset serves every non-"train" split from the test TSV.
    splits=["train", "validation"],
    # Text preprocessing applied to the raw dataset.
    text_preprocessor=[ner_preprocessor],
    # Lowercase text before the metrics are computed.
    postprocess_fn=t5.data.postprocessors.lower_text,
    metric_fns=[
        t5.evaluation.metrics.accuracy,
        t5.evaluation.metrics.sequence_accuracy,
    ],
    # Vocabulary built from the SentencePiece model at `vocab`.
    output_features=t5.data.Feature(
        vocabulary=t5.data.SentencePieceVocabulary(vocab)),
)

### BC5CDR Chemical

In [None]:
def dumping_dataset(split, shuffle_files = False):
    """Load the BC5CDR-chem TSV for `split` as a dataset of dicts.

    Args:
      split: 'train' reads the training file; any other value reads the
        test file.
      shuffle_files: accepted but ignored (it is deleted immediately).

    Returns:
      A tf.data dataset of {"input": ..., "target": ...} dicts, one per line,
      built from the line's two tab-separated columns.
    """
    del shuffle_files
    tsv_path = (
        'gs://scifive/finetune/BC5CDR-chem/train.tsv_cleaned.tsv'
        if split == 'train'
        else 'gs://scifive/finetune/BC5CDR-chem/test.tsv_cleaned.tsv')
    # Parse each "<input>\t<target>" line into two string columns.
    parse_line = functools.partial(
        tf.io.decode_csv, record_defaults=["", ""],
        field_delim="\t", use_quote_delim=False)
    columns = tf.data.TextLineDataset([tsv_path]).map(
        parse_line, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    # Name the two columns so later stages can index them by key.
    return columns.map(lambda *cols: dict(zip(["input", "target"], cols)))

def ner_preprocessor(ds):
  """Map raw {"input", "target"} examples to {"inputs", "targets"} features.

  "inputs" is the example's "input" string prefixed with "bc5cdr_chem_ner: ";
  "targets" is the example's "target" string, unchanged.
  """
  def add_task_prefix(example):
    tagged_input = tf.strings.join(["bc5cdr_chem_ner: ", example["input"]])
    return {"inputs": tagged_input, "targets": example["target"]}

  return ds.map(add_task_prefix,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)
  
# Sanity check: show the first five parsed rows. The loop reads the "train"
# split, so the banner says so (it previously claimed "validation").
print("A few raw training examples...")
for ex in tfds.as_numpy(dumping_dataset("train").take(5)):
  print(ex)

In [None]:
# Drop any earlier registration first -- presumably so this cell can be re-run.
t5.data.TaskRegistry.remove("bc5cdr_chem_ner")
t5.data.TaskRegistry.add(
    "bc5cdr_chem_ner",
    # Function that returns the raw tf.data.Dataset for a split.
    dataset_fn=dumping_dataset,
    # dumping_dataset serves every non-"train" split from the test TSV.
    splits=["train", "validation"],
    # Text preprocessing applied to the raw dataset.
    text_preprocessor=[ner_preprocessor],
    # Lowercase text before the metrics are computed.
    postprocess_fn=t5.data.postprocessors.lower_text,
    metric_fns=[
        t5.evaluation.metrics.accuracy,
        t5.evaluation.metrics.sequence_accuracy,
    ],
    # Vocabulary built from the SentencePiece model at `vocab`.
    output_features=t5.data.Feature(
        vocabulary=t5.data.SentencePieceVocabulary(vocab)),
)

### BC5CDR Disease

In [None]:
def dumping_dataset(split, shuffle_files = False):
    """Load the BC5CDR-disease TSV for `split` as a dataset of dicts.

    Args:
      split: 'train' reads the training file; any other value reads the
        test file.
      shuffle_files: accepted but ignored (it is deleted immediately).

    Returns:
      A tf.data dataset of {"input": ..., "target": ...} dicts, one per line,
      built from the line's two tab-separated columns.
    """
    del shuffle_files
    tsv_path = (
        'gs://scifive/finetune/BC5CDR-disease/train.tsv_cleaned.tsv'
        if split == 'train'
        else 'gs://scifive/finetune/BC5CDR-disease/test.tsv_cleaned.tsv')
    # Parse each "<input>\t<target>" line into two string columns.
    parse_line = functools.partial(
        tf.io.decode_csv, record_defaults=["", ""],
        field_delim="\t", use_quote_delim=False)
    columns = tf.data.TextLineDataset([tsv_path]).map(
        parse_line, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    # Name the two columns so later stages can index them by key.
    return columns.map(lambda *cols: dict(zip(["input", "target"], cols)))

def ner_preprocessor(ds):
  """Map raw {"input", "target"} examples to {"inputs", "targets"} features.

  "inputs" is the example's "input" string prefixed with
  "bc5cdr_disease_ner: "; "targets" is the example's "target" string,
  unchanged.
  """
  def add_task_prefix(example):
    tagged_input = tf.strings.join(["bc5cdr_disease_ner: ", example["input"]])
    return {"inputs": tagged_input, "targets": example["target"]}

  return ds.map(add_task_prefix,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)

# Sanity check: show the first five parsed rows. The loop reads the "train"
# split, so the banner says so (it previously claimed "validation").
print("A few raw training examples...")
for ex in tfds.as_numpy(dumping_dataset("train").take(5)):
  print(ex)

In [None]:
# Drop any earlier registration first -- presumably so this cell can be re-run.
t5.data.TaskRegistry.remove("bc5cdr_disease_ner")
t5.data.TaskRegistry.add(
    "bc5cdr_disease_ner",
    # Function that returns the raw tf.data.Dataset for a split.
    dataset_fn=dumping_dataset,
    # dumping_dataset serves every non-"train" split from the test TSV.
    splits=["train", "validation"],
    # Text preprocessing applied to the raw dataset.
    text_preprocessor=[ner_preprocessor],
    # Lowercase text before the metrics are computed.
    postprocess_fn=t5.data.postprocessors.lower_text,
    metric_fns=[
        t5.evaluation.metrics.accuracy,
        t5.evaluation.metrics.sequence_accuracy,
    ],
    # Vocabulary built from the SentencePiece model at `vocab`.
    output_features=t5.data.Feature(
        vocabulary=t5.data.SentencePieceVocabulary(vocab)),
)

### BC2GM NER Task

In [None]:
def dumping_dataset(split, shuffle_files = False):
    """Load the BC2GM TSV for `split` as a dataset of dicts.

    Args:
      split: 'train' reads the training file; any other value reads the
        test file.
      shuffle_files: accepted but ignored (it is deleted immediately).

    Returns:
      A tf.data dataset of {"input": ..., "target": ...} dicts, one per line,
      built from the line's two tab-separated columns.
    """
    del shuffle_files
    tsv_path = (
        'gs://scifive/finetune/BC2GM/train.tsv_cleaned.tsv'
        if split == 'train'
        else 'gs://scifive/finetune/BC2GM/test.tsv_cleaned.tsv')
    # Parse each "<input>\t<target>" line into two string columns.
    parse_line = functools.partial(
        tf.io.decode_csv, record_defaults=["", ""],
        field_delim="\t", use_quote_delim=False)
    columns = tf.data.TextLineDataset([tsv_path]).map(
        parse_line, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    # Name the two columns so later stages can index them by key.
    return columns.map(lambda *cols: dict(zip(["input", "target"], cols)))

# Sanity check: show the first five parsed rows. The loop reads the "train"
# split, so the banner says so (it previously claimed "validation").
print("A few raw training examples...")
for ex in tfds.as_numpy(dumping_dataset("train").take(5)):
  print(ex)

In [None]:
def ner_preprocessor(ds):
  """Map raw {"input", "target"} examples to {"inputs", "targets"} features.

  "inputs" is the example's "input" string prefixed with "bc2gm_ner: ";
  "targets" is the example's "target" string, unchanged.
  """
  def add_task_prefix(example):
    tagged_input = tf.strings.join(["bc2gm_ner: ", example["input"]])
    return {"inputs": tagged_input, "targets": example["target"]}

  return ds.map(add_task_prefix,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)

In [None]:
# Drop any earlier registration first -- presumably so this cell can be re-run.
t5.data.TaskRegistry.remove("bc2gm_ner")
t5.data.TaskRegistry.add(
    "bc2gm_ner",
    # Function that returns the raw tf.data.Dataset for a split.
    dataset_fn=dumping_dataset,
    # dumping_dataset serves every non-"train" split from the test TSV.
    splits=["train", "validation"],
    # Text preprocessing applied to the raw dataset.
    text_preprocessor=[ner_preprocessor],
    # Lowercase text before the metrics are computed.
    postprocess_fn=t5.data.postprocessors.lower_text,
    metric_fns=[
        t5.evaluation.metrics.accuracy,
        t5.evaluation.metrics.sequence_accuracy,
    ],
    # Vocabulary built from the SentencePiece model at `vocab`.
    output_features=t5.data.Feature(
        vocabulary=t5.data.SentencePieceVocabulary(vocab)),
)

###  BC4CHEMD NER Task

In [None]:
def dumping_dataset(split, shuffle_files = False):
    """Load the BC4CHEMD TSV for `split` as a dataset of dicts.

    Args:
      split: 'train' reads the training file; any other value reads the
        test file.
      shuffle_files: accepted but ignored (it is deleted immediately).

    Returns:
      A tf.data dataset of {"input": ..., "target": ...} dicts, one per line,
      built from the line's two tab-separated columns.
    """
    del shuffle_files
    tsv_path = (
        'gs://scifive/finetune/BC4CHEMD/train.tsv_cleaned.tsv'
        if split == 'train'
        else 'gs://scifive/finetune/BC4CHEMD/test.tsv_cleaned.tsv')
    # Parse each "<input>\t<target>" line into two string columns.
    parse_line = functools.partial(
        tf.io.decode_csv, record_defaults=["", ""],
        field_delim="\t", use_quote_delim=False)
    columns = tf.data.TextLineDataset([tsv_path]).map(
        parse_line, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    # Name the two columns so later stages can index them by key.
    return columns.map(lambda *cols: dict(zip(["input", "target"], cols)))

def ner_preprocessor(ds):
  """Map raw {"input", "target"} examples to {"inputs", "targets"} features.

  "inputs" is the example's "input" string prefixed with "bc4chemd_ner: ";
  "targets" is the example's "target" string, unchanged.
  """
  def add_task_prefix(example):
    tagged_input = tf.strings.join(["bc4chemd_ner: ", example["input"]])
    return {"inputs": tagged_input, "targets": example["target"]}

  return ds.map(add_task_prefix,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)

# Sanity check: show the first five parsed rows. The loop reads the "train"
# split, so the banner says so (it previously claimed "validation").
print("A few raw training examples...")
for ex in tfds.as_numpy(dumping_dataset("train").take(5)):
  print(ex)

In [None]:
# Drop any earlier registration first -- presumably so this cell can be re-run.
t5.data.TaskRegistry.remove("bc4chemd_ner")
t5.data.TaskRegistry.add(
    "bc4chemd_ner",
    # Function that returns the raw tf.data.Dataset for a split.
    dataset_fn=dumping_dataset,
    # dumping_dataset serves every non-"train" split from the test TSV.
    splits=["train", "validation"],
    # Text preprocessing applied to the raw dataset.
    text_preprocessor=[ner_preprocessor],
    # Lowercase text before the metrics are computed.
    postprocess_fn=t5.data.postprocessors.lower_text,
    metric_fns=[
        t5.evaluation.metrics.accuracy,
        t5.evaluation.metrics.sequence_accuracy,
    ],
    # Vocabulary built from the SentencePiece model at `vocab`.
    output_features=t5.data.Feature(
        vocabulary=t5.data.SentencePieceVocabulary(vocab)),
)

### JNLPBA NER Task


In [None]:
def dumping_dataset(split, shuffle_files = False):
    """Load the JNLPBA TSV for `split` as a dataset of dicts.

    Args:
      split: 'train' reads the training file; any other value reads the
        test file.
      shuffle_files: accepted but ignored (it is deleted immediately).

    Returns:
      A tf.data dataset of {"input": ..., "target": ...} dicts, one per line,
      built from the line's two tab-separated columns.
    """
    del shuffle_files
    tsv_path = (
        'gs://scifive/finetune/JNLPBA/train.tsv_cleaned.tsv'
        if split == 'train'
        else 'gs://scifive/finetune/JNLPBA/test.tsv_cleaned.tsv')
    # Parse each "<input>\t<target>" line into two string columns.
    parse_line = functools.partial(
        tf.io.decode_csv, record_defaults=["", ""],
        field_delim="\t", use_quote_delim=False)
    columns = tf.data.TextLineDataset([tsv_path]).map(
        parse_line, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    # Name the two columns so later stages can index them by key.
    return columns.map(lambda *cols: dict(zip(["input", "target"], cols)))

def ner_preprocessor(ds):
  """Map raw {"input", "target"} examples to {"inputs", "targets"} features.

  "inputs" is the example's "input" string prefixed with "jnlpba_ner: ";
  "targets" is the example's "target" string, unchanged.
  """
  def add_task_prefix(example):
    tagged_input = tf.strings.join(["jnlpba_ner: ", example["input"]])
    return {"inputs": tagged_input, "targets": example["target"]}

  return ds.map(add_task_prefix,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)

# Sanity check: show the first five parsed rows. The loop reads the "train"
# split, so the banner says so (it previously claimed "validation").
print("A few raw training examples...")
for ex in tfds.as_numpy(dumping_dataset("train").take(5)):
  print(ex)

In [None]:
# Drop any earlier registration first -- presumably so this cell can be re-run.
t5.data.TaskRegistry.remove("jnlpba_ner")
t5.data.TaskRegistry.add(
    "jnlpba_ner",
    # Function that returns the raw tf.data.Dataset for a split.
    dataset_fn=dumping_dataset,
    # dumping_dataset serves every non-"train" split from the test TSV.
    splits=["train", "validation"],
    # Text preprocessing applied to the raw dataset.
    text_preprocessor=[ner_preprocessor],
    # Lowercase text before the metrics are computed.
    postprocess_fn=t5.data.postprocessors.lower_text,
    metric_fns=[
        t5.evaluation.metrics.accuracy,
        t5.evaluation.metrics.sequence_accuracy,
    ],
    # Vocabulary built from the SentencePiece model at `vocab`.
    output_features=t5.data.Feature(
        vocabulary=t5.data.SentencePieceVocabulary(vocab)),
)

### LINNAEUS NER Task

In [None]:
def dumping_dataset(split, shuffle_files = False):
    """Load the linnaeus TSV for `split` as a dataset of dicts.

    Args:
      split: 'train' reads the training file; any other value reads the
        test file.
      shuffle_files: accepted but ignored (it is deleted immediately).

    Returns:
      A tf.data dataset of {"input": ..., "target": ...} dicts, one per line,
      built from the line's two tab-separated columns.
    """
    del shuffle_files
    tsv_path = (
        'gs://scifive/finetune/linnaeus/train.tsv_cleaned.tsv'
        if split == 'train'
        else 'gs://scifive/finetune/linnaeus/test.tsv_cleaned.tsv')
    # Parse each "<input>\t<target>" line into two string columns.
    parse_line = functools.partial(
        tf.io.decode_csv, record_defaults=["", ""],
        field_delim="\t", use_quote_delim=False)
    columns = tf.data.TextLineDataset([tsv_path]).map(
        parse_line, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    # Name the two columns so later stages can index them by key.
    return columns.map(lambda *cols: dict(zip(["input", "target"], cols)))

def ner_preprocessor(ds):
  """Map raw {"input", "target"} examples to {"inputs", "targets"} features.

  "inputs" is the example's "input" string prefixed with "linnaeus_ner: ";
  "targets" is the example's "target" string, unchanged.
  """
  def add_task_prefix(example):
    tagged_input = tf.strings.join(["linnaeus_ner: ", example["input"]])
    return {"inputs": tagged_input, "targets": example["target"]}

  return ds.map(add_task_prefix,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)

# Sanity check: show the first five parsed rows. The loop reads the "train"
# split, so the banner says so (it previously claimed "validation").
print("A few raw training examples...")
for ex in tfds.as_numpy(dumping_dataset("train").take(5)):
  print(ex)

In [None]:
# Drop any earlier registration first -- presumably so this cell can be re-run.
t5.data.TaskRegistry.remove("linnaeus_ner")
t5.data.TaskRegistry.add(
    "linnaeus_ner",
    # Function that returns the raw tf.data.Dataset for a split.
    dataset_fn=dumping_dataset,
    # dumping_dataset serves every non-"train" split from the test TSV.
    splits=["train", "validation"],
    # Text preprocessing applied to the raw dataset.
    text_preprocessor=[ner_preprocessor],
    # Lowercase text before the metrics are computed.
    postprocess_fn=t5.data.postprocessors.lower_text,
    metric_fns=[
        t5.evaluation.metrics.accuracy,
        t5.evaluation.metrics.sequence_accuracy,
    ],
    # Vocabulary built from the SentencePiece model at `vocab`.
    output_features=t5.data.Feature(
        vocabulary=t5.data.SentencePieceVocabulary(vocab)),
)

### S800 NER

In [None]:
def dumping_dataset(split, shuffle_files = False):
    """Load the s800 TSV for `split` as a dataset of dicts.

    Args:
      split: 'train' reads the training file; any other value reads the
        test file.
      shuffle_files: accepted but ignored (it is deleted immediately).

    Returns:
      A tf.data dataset of {"input": ..., "target": ...} dicts, one per line,
      built from the line's two tab-separated columns.
    """
    del shuffle_files
    tsv_path = (
        'gs://scifive/finetune/s800/train.tsv_cleaned.tsv'
        if split == 'train'
        else 'gs://scifive/finetune/s800/test.tsv_cleaned.tsv')
    # Parse each "<input>\t<target>" line into two string columns.
    parse_line = functools.partial(
        tf.io.decode_csv, record_defaults=["", ""],
        field_delim="\t", use_quote_delim=False)
    columns = tf.data.TextLineDataset([tsv_path]).map(
        parse_line, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    # Name the two columns so later stages can index them by key.
    return columns.map(lambda *cols: dict(zip(["input", "target"], cols)))

def ner_preprocessor(ds):
  """Map raw {"input", "target"} examples to {"inputs", "targets"} features.

  "inputs" is the example's "input" string prefixed with "s800_ner: ";
  "targets" is the example's "target" string, unchanged.
  """
  def add_task_prefix(example):
    tagged_input = tf.strings.join(["s800_ner: ", example["input"]])
    return {"inputs": tagged_input, "targets": example["target"]}

  return ds.map(add_task_prefix,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)

# Sanity check: show the first five parsed rows. The loop reads the "train"
# split, so the banner says so (it previously claimed "validation").
print("A few raw training examples...")
for ex in tfds.as_numpy(dumping_dataset("train").take(5)):
  print(ex)

In [None]:
# Drop any earlier registration first -- presumably so this cell can be re-run.
t5.data.TaskRegistry.remove("s800_ner")
t5.data.TaskRegistry.add(
    "s800_ner",
    # Function that returns the raw tf.data.Dataset for a split.
    dataset_fn=dumping_dataset,
    # dumping_dataset serves every non-"train" split from the test TSV.
    splits=["train", "validation"],
    # Text preprocessing applied to the raw dataset.
    text_preprocessor=[ner_preprocessor],
    # Lowercase text before the metrics are computed.
    postprocess_fn=t5.data.postprocessors.lower_text,
    metric_fns=[
        t5.evaluation.metrics.accuracy,
        t5.evaluation.metrics.sequence_accuracy,
    ],
    # Vocabulary built from the SentencePiece model at `vocab`.
    output_features=t5.data.Feature(
        vocabulary=t5.data.SentencePieceVocabulary(vocab)),
)

## Mixtures

In [None]:
# Combine the eight NER tasks registered above into one mixture, each with
# the same default rate.
t5.data.MixtureRegistry.remove("ner_all")
t5.data.MixtureRegistry.add(
    "ner_all",
    [
        "ncbi_ner",
        "bc5cdr_disease_ner",
        "bc5cdr_chem_ner",
        "bc4chemd_ner",
        "bc2gm_ner",
        "jnlpba_ner",
        "linnaeus_ner",
        "s800_ner",
    ],
    default_rate=1.0,
)

## Define Model

In [None]:


# Model size; selects both the checkpoint sub-directory and the presets below.
MODEL_SIZE = "base"
# BASE_PRETRAINED_DIR = "gs://t5-data/pretrained_models"
# Pretrained checkpoints are read from <BASE_PRETRAINED_DIR>/<MODEL_SIZE>.
BASE_PRETRAINED_DIR = "gs://t5_training/models/bio/pmc_v2"
PRETRAINED_DIR = os.path.join(BASE_PRETRAINED_DIR, MODEL_SIZE)
# Fine-tuned checkpoints are written to .../ner_all_pmc_v2/<MODEL_SIZE>.
MODEL_DIR = os.path.join(
    "gs://t5_training/models/bio/ner_all_pmc_v2", MODEL_SIZE)

# Per-size presets: (model_parallelism, train_batch_size, keep_checkpoint_max).
# NOTE(review): the original comments said these were sized for a v2-8 TPU and
# a 5GB checkpoint budget, while TPU_TOPOLOGY is "v3-8" -- confirm they still fit.
model_parallelism, train_batch_size, keep_checkpoint_max = {
    "small": (1, 256, 16),
    "base": (2, 128*2, 8),
    "large": (8, 64, 4),
    "3B": (8, 16, 1),
    "11B": (8, 16, 1),
}[MODEL_SIZE]

tf.io.gfile.makedirs(MODEL_DIR)
# Mesh TensorFlow Transformer model wrapper from the t5 package.
model = t5.models.MtfModel(
    model_dir=MODEL_DIR,
    tpu=TPU_ADDRESS,
    tpu_topology=TPU_TOPOLOGY,
    model_parallelism=model_parallelism,
    batch_size=train_batch_size,
    sequence_length={'inputs': 256, 'targets': 256},
    learning_rate_schedule=0.001,
    save_checkpoints_steps=1000,
    # The checkpoint cap is only applied when running on cloud.
    keep_checkpoint_max=keep_checkpoint_max if ON_CLOUD else None,
    iterations_per_loop=100,
)


In [None]:
# Inspect the vocabulary object the "ner_all" mixture resolves to, and its size.
vocabulary1 = t5.data.get_mixture_or_task("ner_all").get_vocabulary()
print(" ******** vocabulary1")
print(vocabulary1)
print(vocabulary1.vocab_size)


## Finetune 

In [None]:
if ON_CLOUD:
  %reload_ext tensorboard
  import tensorboard as tb
tb.notebook.start("--logdir " + MODEL_DIR)

In [None]:
# Number of fine-tuning steps passed to model.finetune below.
FINETUNE_STEPS = 45000

# Fine-tune on the "ner_all" mixture, starting from the checkpoints in PRETRAINED_DIR.
model.finetune(
    mixture_or_task_name="ner_all",
    pretrained_model_dir=PRETRAINED_DIR,
    finetune_steps=FINETUNE_STEPS
)

## Export Model

In [None]:
# The export is written under <MODEL_DIR>/export.
export_dir = os.path.join(MODEL_DIR, "export")

model.batch_size = 1 # make one prediction per call
# model.export returns the path of the saved model, which later cells load.
saved_model_path = model.export(
    export_dir,
    checkpoint_step=-1,  # use most recent
    beam_size=1,  # no beam search
    temperature=1.0,  # sample according to predicted distribution
)
print("Model saved to:", saved_model_path)

## Load Saved Model

In [None]:
import tensorflow as tf

import tensorflow_text  # Required to run exported model.
print(tf.__version__)
def load_predict_fn(model_path):
  """Load the SavedModel at `model_path` and return a prediction callable.

  In eager mode the model is loaded with `tf.saved_model.load` and the
  callable returns the 'outputs' entry of the 'serving_default' signature as
  a NumPy value. Otherwise the model is loaded into a fresh TF1 session and
  the callable runs the signature's "outputs" tensor, feeding its "input".

  Args:
    model_path: path of the exported SavedModel.

  Returns:
    A one-argument callable mapping an input batch `x` to model outputs.
  """
  if not tf.executing_eagerly():
    print("Loading SavedModel in tf 1.x graph mode.")
    tf.compat.v1.reset_default_graph()
    sess = tf.compat.v1.Session()
    meta_graph_def = tf.compat.v1.saved_model.load(sess, ["serve"], model_path)
    signature_def = meta_graph_def.signature_def["serving_default"]

    def predict_in_session(x):
      return sess.run(
          fetches=signature_def.outputs["outputs"].name,
          feed_dict={signature_def.inputs["input"].name: x})

    return predict_in_session

  print("Loading SavedModel in eager mode.")
  imported = tf.saved_model.load(model_path, ["serve"])

  def predict_eagerly(x):
    return imported.signatures['serving_default'](tf.constant(x))['outputs'].numpy()

  return predict_eagerly

# Build the prediction callable from the path returned by model.export.
predict_fn = load_predict_fn(saved_model_path)

In [None]:
# Scratch cell: load the exported SavedModel directly and wrap its serving
# signature. Uses `saved_model_path` (returned by model.export); `model_path`
# exists only as a parameter of load_predict_fn, so it was undefined here.
imported = tf.saved_model.load(saved_model_path, ["serve"])
test = lambda x: imported.signatures['serving_default'](tf.constant(x))['outputs'].numpy()

## Predict

In [None]:
# [data sub-directory under gs://t5_training/t5-data/bio_data, registered task name] pairs.
tasks = [
         ['NCBI-disease', "ncbi_ner"], 
         ['BC5CDR-disease', "bc5cdr_disease_ner"], 
         ['BC5CDR-chem', "bc5cdr_chem_ner"], 
         ['BC4CHEMD', 'bc4chemd_ner'], 
         ['BC2GM', 'bc2gm_ner'], 
         ['JNLPBA', 'jnlpba_ner'], 
         ['linnaeus', 'linnaeus_ner'], 
         ['s800', 's800_ner']
         ]
# Sub-directory of each task's data directory that receives the prediction files.
output_dir = 'ner_all_pubmed_v2'

In [None]:
import tensorflow.compat.v1 as tf
# question_1 = "Emerin is a nuclear membrane protein which is missing or defective in Emery-Dreifuss muscular dystrophy (EDMD). It is one member of a family of lamina-associated proteins which includes LAP1, LAP2 and lamin B receptor (LBR). A panel of 16 monoclonal antibodies (mAbs) has been mapped to six specific sites throughout the emerin molecule using phage-displayed peptide libraries and has been used to localize emerin in human and rabbit heart. Several mAbs against different emerin epitopes did not recognize intercalated discs in the heart, though they recognized cardiomyocyte nuclei strongly, both at the rim and in intranuclear spots or channels. A polyclonal rabbit antiserum against emerin did recognize both nuclear membrane and intercalated discs but, after affinity purification against a pure-emerin band on a western blot, it stained only the nuclear membrane. These results would not be expected if immunostaining at intercalated discs were due to a product of the emerin gene and, therefore, cast some doubt upon the hypothesis that cardiac defects in EDMD are caused by absence of emerin from intercalated discs. Although emerin was abundant in the membranes of cardiomyocyte nuclei, it was absent from many non-myocyte cells in the heart. This distribution of emerin was similar to that of lamin A, a candidate gene for an autosomal form of EDMD. In contrast, lamin B1 was absent from cardiomyocyte nuclei, showing that lamin B1 is not essential for localization of emerin to the nuclear lamina. Lamin B1 is also almost completely absent from skeletal muscle nuclei. In EDMD, the additional absence of lamin B1 from heart and skeletal muscle nuclei which already lack emerin may offer an alternative explanation of why these tissues are particularly affected.." 
# question_2 = "Molecular analysis of the APC gene in 205 families: extended genotype-phenotype correlations in FAP and evidence for the role of APC amino acid changes in colorectal cancer predisposition." 
# question_3 = "Who are the 4 members of The Beatles?" 
# question_4 = "How many teeth do humans have?"

# questions = [question_2]


def _checkpoint_step(path):
  """Return the integer step suffix of "<file>-<step>", or -1 if there is none."""
  suffix = path.rsplit("-", 1)[-1]
  return int(suffix) if suffix.isdigit() else -1


for t in tasks:
  # `data_dir` rather than `dir`, so the builtin dir() is not shadowed.
  data_dir, task = t
  input_file = task + '_predict_input.txt'
  output_file = task + '_predict_output.txt'

  # Inputs are read from the task's data directory; outputs are written to
  # <data directory>/<output_dir>/<MODEL_SIZE>/.
  predict_inputs_path = os.path.join('gs://t5_training/t5-data/bio_data', data_dir, input_file)
  predict_outputs_path = os.path.join('gs://t5_training/t5-data/bio_data', data_dir, output_dir, MODEL_SIZE, output_file)

  # Silence everything below ERROR while predicting.
  with tf_verbosity_level('ERROR'):
    model.batch_size = 8
    model.predict(
        input_file=predict_inputs_path,
        output_file=predict_outputs_path,
        # Select the most probable output token at each step.
        vocabulary=t5.data.SentencePieceVocabulary(vocab),
        temperature=0,
    )

  # The output filename has the checkpoint step appended, so glob for it.
  # Sort numerically by step: a plain string sort would rank "...-999900"
  # after "...-1000000".
  prediction_files = sorted(
      tf.io.gfile.glob(predict_outputs_path + "*"),
      key=lambda path: (_checkpoint_step(path), path))
  print("Predicted task : " + task)
  if prediction_files:
    print("\nPredictions using checkpoint %s:\n" % prediction_files[-1].split("-")[-1])
  else:
    # Report the missing output instead of failing with an IndexError.
    print("\nNo prediction files found matching %s*\n" % predict_outputs_path)

## Scoring

In [None]:
!pip install seqeval
import nltk
from seqeval.metrics import f1_score, accuracy_score, classification_report, recall_score, precision_score
import re
import os
# [data sub-directory under gs://t5_training/t5-data/bio_data, task name] pairs;
# the cells below use them to locate the prediction and reference files.
tasks = [
         ['NCBI-disease', "ncbi_ner"], 
         ['BC5CDR-disease', "bc5cdr_disease_ner"], 
         ['BC5CDR-chem', "bc5cdr_chem_ner"], 
         ['BC4CHEMD', 'bc4chemd_ner'], 
         ['BC2GM', 'bc2gm_ner'], 
         ['JNLPBA', 'jnlpba_ner'], 
         ['linnaeus', 'linnaeus_ner'], 
         ['s800', 's800_ner']
         ]

In [None]:
# NOTE(review): `checkpoint` is not read in this cell, and the scoring cell
# assigns a different value before using it -- confirm this line is still needed.
checkpoint = 249600
for t in tasks:
  # Copy every prediction file for this task into the working directory.
  # NOTE(review): the model size is hard-coded as "base" here rather than taken from MODEL_SIZE.
  !gsutil cp gs://t5_training/t5-data/bio_data/{t[0]}/{output_dir}/base/{t[1]}_predict_output.txt-* .
  # !gsutil cp gs://t5_training/t5-data/bio_data/{t[0]}/{t[1]}_predict_output.txt-* .
  # t5_training/t5-data/bio_data/BC4CHEMD/predicted_output_original_model/base
  # Copy this task's reference ("actual") output file.
  !gsutil cp gs://t5_training/t5-data/bio_data/{t[0]}/{t[1]}_actual_output.txt .


In [None]:
def convert_BIO_labels(filename):
    """Convert a file of entity-marked sentences into per-line BIO label lists.

    Each line is tokenized on whitespace after punctuation is replaced by
    spaces. A token ending in '*' is an entity marker ("type*"; a "*type"
    form is first rewritten to "type*") and produces no label. Every other
    token produces 'O' when no entity is open, 'B-TYPE' when it directly
    follows a marker, and 'I-TYPE' otherwise, with TYPE the upper-cased
    marker text.

    Args:
        filename: path of a UTF-8 text file with one sentence per line.

    Returns:
        A list with one list of label strings per line of the file.

    Raises:
        ValueError: if the open-entity counter ever becomes negative.
    """
    result_labels = []
    with open(filename, 'r', encoding='utf-8') as file:
        for line in file:
            # Rewrite "*type" markers to the "type*" form tested below.
            line = re.sub(r'\*(\w+)', r'\1*', line)
            # '*' is not in the punctuation class, so markers survive this step.
            tokens = re.sub(r'[!"#$%&\'()+,-.:;<=>?@[\\\]^_`{\|}~⁇]', ' ', line.strip()).split()
            seq_label = []
            start_entity = 0
            entity_type = 'O'
            for idx, token in enumerate(tokens):
                if token.endswith('*'):
                    # A marker opens an entity when none is open or when it
                    # names a different type; a marker repeating the current
                    # type closes it.
                    start_entity += 1 if (start_entity == 0 or token[:-1] != entity_type) else -1
                    entity_type = token[:-1]
                else:
                    if start_entity == 0:
                        seq_label.append('O')
                        entity_type = 'O'
                    elif start_entity < 0:
                        # Defensive check. The original raised a plain string,
                        # which is itself a TypeError in Python 3.
                        raise ValueError(
                            'Unbalanced entity markers in %s: %r' % (filename, line))
                    else:
                        if tokens[idx - 1].endswith('*'):
                            seq_label.append('B-' + entity_type.upper())
                        else:
                            seq_label.append('I-' + entity_type.upper())

            result_labels.append(seq_label)
    return result_labels

In [None]:
# pred_file = 't5-data_bio_data_NCBI_NER_predict_outputs_1603446926.txt-1017500'
# actual_file = 'test_raw.txt'
# pred_file = 'data/ncbi/t5-data_bio_data_NCBI_NER_predict_outputs_1603446926.txt-1017500'
# actual_file = 'data/ncbi/test_raw.txt'
# Checkpoint step whose prediction files are scored.
checkpoint = 237400
for task in tasks:
    t = task[1]

    pred_file = t + '_predict_output.txt-%d' % checkpoint
    actual_file = t + '_actual_output.txt'

    pred_labels = convert_BIO_labels(pred_file)
    actual_labels = convert_BIO_labels(actual_file)
    # Blank separator line. The original `;print` is IPython-only auto-quote
    # syntax for print("") and a SyntaxError in plain Python.
    print()
    # Make each predicted sequence the same length as its reference:
    # truncate the surplus, or pad the shortfall with 'PAD' labels.
    for i, (pred_seq, actual_seq) in enumerate(zip(pred_labels, actual_labels)):
        shortfall = len(actual_seq) - len(pred_seq)
        if shortfall < 0:
            pred_labels[i] = pred_seq[:len(actual_seq)]
        elif shortfall > 0:
            pred_labels[i] = pred_seq + ['PAD'] * shortfall

    f1score = f1_score(actual_labels, pred_labels)
    recallscore = recall_score(actual_labels, pred_labels)
    precisionscore = precision_score(actual_labels, pred_labels)

    print(t, 'f1-score', f1score)
    print('recallscore', recallscore)
    print('precisionscore', precisionscore)
    # print("Report:", classification_report(actual_labels, pred_labels, digits=4))