<a href="https://colab.research.google.com/github/google-research/tapas/blob/master/notebooks/sqa_predictions.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

##### Copyright 2020 The Google AI Language Team Authors

Licensed under the Apache License, Version 2.0 (the "License");

In [None]:
# Copyright 2019 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Running a Tapas fine-tuned checkpoint
---
This notebook shows how to load a TAPAS model and make predictions with it. TAPAS was introduced in the paper [TAPAS: Weakly Supervised Table Parsing via Pre-training](https://arxiv.org/abs/2004.02349).

# Install the library

First, let's install the `tapas-table-parsing` package with pip.

In [1]:
! pip install tapas-table-parsing

Collecting tapas-table-parsing
  Downloading tapas_table_parsing-0.0.1.dev0-py3-none-any.whl (195 kB)
[?25l[K     |█▊                              | 10 kB 22.3 MB/s eta 0:00:01[K     |███▍                            | 20 kB 28.7 MB/s eta 0:00:01[K     |█████                           | 30 kB 16.4 MB/s eta 0:00:01[K     |██████▊                         | 40 kB 11.4 MB/s eta 0:00:01[K     |████████▍                       | 51 kB 5.7 MB/s eta 0:00:01[K     |██████████                      | 61 kB 6.7 MB/s eta 0:00:01[K     |███████████▊                    | 71 kB 7.3 MB/s eta 0:00:01[K     |█████████████▍                  | 81 kB 5.6 MB/s eta 0:00:01[K     |███████████████                 | 92 kB 6.2 MB/s eta 0:00:01[K     |████████████████▊               | 102 kB 6.8 MB/s eta 0:00:01[K     |██████████████████▍             | 112 kB 6.8 MB/s eta 0:00:01[K     |████████████████████            | 122 kB 6.8 MB/s eta 0:00:01[K     |█████████████████████▉          | 1

# Fetch models from Google Storage

Next, download a pretrained checkpoint from Google Storage. For the sake of speed, this is a base-sized model (12 layers) fine-tuned on [SQA](https://www.microsoft.com/en-us/download/details.aspx?id=54253). Note that the best results in the paper were obtained with a large model, which has 24 layers instead of 12.

In [1]:
! gsutil cp gs://tapas_models/2020_04_21/tapas_sqa_base.zip . && unzip -o tapas_sqa_base.zip

Copying gs://tapas_models/2020_04_21/tapas_sqa_base.zip...
| [1 files][  1.0 GiB/  1.0 GiB]   51.4 MiB/s                                   
Operation completed over 1 objects/1.0 GiB.                                      
Archive:  tapas_sqa_base.zip
  inflating: tapas_sqa_base/model.ckpt.data-00000-of-00001
  inflating: tapas_sqa_base/model.ckpt.index  
  inflating: tapas_sqa_base/README.txt  
  inflating: tapas_sqa_base/vocab.txt  
  inflating: tapas_sqa_base/bert_config.json  
  inflating: tapas_sqa_base/model.ckpt.meta  
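The text above describes this checkpoint as a base-sized model with 12 layers, while the paper's best results used a 24-layer large model. One way to confirm which size you downloaded is to read the `bert_config.json` file that was just unpacked. This is a minimal sketch that assumes the file uses the standard BERT config keys `num_hidden_layers`, `hidden_size` and `num_attention_heads`; `summarize_bert_config` is a hypothetical helper, not part of the `tapas` package.

```python
import json
from pathlib import Path


def summarize_bert_config(config_path):
    """Return the size-related fields of a BERT-style config file.

    Assumes the standard BERT config keys; a missing key comes back as None.
    """
    config = json.loads(Path(config_path).read_text())
    keys = ("num_hidden_layers", "hidden_size", "num_attention_heads")
    return {key: config.get(key) for key in keys}


if __name__ == "__main__":
    path = Path("tapas_sqa_base/bert_config.json")
    if path.exists():
        # A base-sized model is expected to report 12 hidden layers.
        print(summarize_bert_config(path))
    else:
        print(f"{path} not found; download and unzip the checkpoint first.")
```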


# Imports

In [2]:
import csv
import os
import shutil

import IPython
import pandas as pd
import tensorflow.compat.v1 as tf
from IPython.display import display

tf.get_logger().setLevel('ERROR')

In [3]:
from tapas.utils import tf_example_utils
from tapas.protos import interaction_pb2
from tapas.utils import number_annotation_utils
from tapas.scripts import prediction_utils

# Load checkpoint for prediction

Here's the prediction code. It creates an `interaction_pb2.Interaction` protobuf object, the data structure we use to store examples, and then calls the prediction script.
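Each question in an interaction gets an id of the form `f"{idx}-0_{position}"`, and after prediction the notebook reads the `position` column of `test_sequence.tsv` to map each answer back to its query. The sketch below builds and splits such ids in plain Python to make that round trip explicit. Reading the middle `0` as an annotator index (as in SQA-style ids) is an assumption, and `make_question_ids` and `parse_question_id` are hypothetical helpers, not the `tapas` API.

```python
import re

# Matches ids built as f"{idx}-0_{position}" in convert_interactions_to_examples.
QUESTION_ID_PATTERN = re.compile(
    r"^(?P<interaction>.+)-(?P<annotator>\d+)_(?P<position>\d+)$")


def make_question_ids(interaction_index, queries):
    """Build one id per query, mirroring the notebook's f-string."""
    return [f"{interaction_index}-0_{position}"
            for position, _ in enumerate(queries)]


def parse_question_id(question_id):
    """Split a question id into (interaction id, annotator, position)."""
    match = QUESTION_ID_PATTERN.match(question_id)
    if match is None:
        raise ValueError(f"Unexpected question id: {question_id!r}")
    return match["interaction"], int(match["annotator"]), int(match["position"])


example_queries = ["How many doctors are there in Immunology department?",
                   "of these, which doctor is available on Saturday?"]
for question_id in make_question_ids(0, example_queries):
    # The position recovered here is the index into example_queries.
    print(question_id, parse_question_id(question_id))
```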

In [4]:
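# Create the directory layout used below for the SQA task.
os.makedirs('results/sqa/tf_examples', exist_ok=True)
os.makedirs('results/sqa/model', exist_ok=True)
# Expose the downloaded weights in the model directory as checkpoint
# "model.ckpt-0" and record that name in the `checkpoint` state file.
with open('results/sqa/model/checkpoint', 'w') as f:
  f.write('model_checkpoint_path: "model.ckpt-0"')
for suffix in ['.data-00000-of-00001', '.index', '.meta']:
  shutil.copyfile(f'tapas_sqa_base/model.ckpt{suffix}', f'results/sqa/model/model.ckpt-0{suffix}')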
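The cell above copies three checkpoint files under the prefix `model.ckpt-0` and writes a one-line `checkpoint` state file naming that prefix. If a copy fails silently or a file is missing, prediction may not find the weights, so a quick consistency check can help. This is a sketch under stated assumptions: it only looks for the `model_checkpoint_path: "..."` entry that the cell writes (not the full checkpoint-state format), it checks the same three suffixes the cell copies, and `missing_checkpoint_files` is a hypothetical helper.

```python
import re
from pathlib import Path

# The same three suffixes the notebook copies.
SUFFIXES = (".data-00000-of-00001", ".index", ".meta")


def missing_checkpoint_files(model_dir):
    """List checkpoint files named by `model_dir/checkpoint` that do not exist."""
    model_dir = Path(model_dir)
    state = (model_dir / "checkpoint").read_text()
    match = re.search(r'model_checkpoint_path:\s*"([^"]+)"', state)
    if match is None:
        raise ValueError("No model_checkpoint_path entry in the checkpoint file")
    prefix = match.group(1)
    return [f"{prefix}{suffix}" for suffix in SUFFIXES
            if not (model_dir / f"{prefix}{suffix}").exists()]


if __name__ == "__main__":
    model_dir = Path("results/sqa/model")
    if (model_dir / "checkpoint").exists():
        # An empty list means all three files were copied.
        print(missing_checkpoint_files(model_dir))
```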

In [5]:
max_seq_length = 512
vocab_file = "tapas_sqa_base/vocab.txt"
config = tf_example_utils.ClassifierConversionConfig(
    vocab_file=vocab_file,
    max_seq_length=max_seq_length,
    max_column_id=max_seq_length,
    max_row_id=max_seq_length,
    strip_column_names=False,
    add_aggregation_candidates=False,
)
converter = tf_example_utils.ToClassifierTensorflowExample(config)

def convert_interactions_to_examples(tables_and_queries):
  """Calls Tapas converter to convert interaction to example."""
  for idx, (table, queries) in enumerate(tables_and_queries):
    interaction = interaction_pb2.Interaction()
    for position, query in enumerate(queries):
      question = interaction.questions.add()
      question.original_text = query
      question.id = f"{idx}-0_{position}"
    for header in table[0]:
      interaction.table.columns.add().text = header
    for line in table[1:]:
      row = interaction.table.rows.add()
      for cell in line:
        row.cells.add().text = cell
    number_annotation_utils.add_numeric_values(interaction)
    for i in range(len(interaction.questions)):
      try:
        yield converter.convert(interaction, i)
      except ValueError as e:
        # interaction.id is never set above, so report the question id instead.
        print(f"Can't convert question {interaction.questions[i].id}: {e}")
        
def write_tf_example(filename, examples):
  with tf.io.TFRecordWriter(filename) as writer:
    for example in examples:
      writer.write(example.SerializeToString())

def predict(table_data, queries):
  table = [list(map(lambda s: s.strip(), row.split("|"))) 
           for row in table_data.split("\n") if row.strip()]
  examples = convert_interactions_to_examples([(table, queries)])
  write_tf_example("results/sqa/tf_examples/test.tfrecord", examples)
  write_tf_example("results/sqa/tf_examples/random-split-1-dev.tfrecord", [])
  
  ! python -m tapas.run_task_main \
    --task="SQA" \
    --output_dir="results" \
    --noloop_predict \
    --test_batch_size={len(queries)} \
    --tapas_verbosity="ERROR" \
    --compression_type= \
    --init_checkpoint="tapas_sqa_base/model.ckpt" \
    --bert_config_file="tapas_sqa_base/bert_config.json" \
    --mode="predict" 2> error


  results_path = "results/sqa/model/test_sequence.tsv"
  all_coordinates = []
  df = pd.DataFrame(table[1:], columns=table[0])
  display(IPython.display.HTML(df.to_html(index=False)))
  print()
  with open(results_path) as csvfile:
    reader = csv.DictReader(csvfile, delimiter='\t')
    for row in reader:
      coordinates = prediction_utils.parse_coordinates(row["answer_coordinates"])
      all_coordinates.append(coordinates)
      # Coordinates index data rows only, so add 1 to skip the header row.
      answers = ', '.join([table[row_index + 1][col_index]
                           for row_index, col_index in coordinates])
      position = int(row['position'])
      print(">", queries[position])
      print(answers)
  return all_coordinates
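`predict()` expects the table as a single string: one row per line, cells separated by `|`, blank lines ignored, and the first remaining line used as the header. The sketch below reproduces that parsing on its own so you can check a table before running the model. The row-width check is an extra safeguard that the notebook does not perform, and `parse_pipe_table` is a hypothetical helper.

```python
def parse_pipe_table(table_data):
    """Split a '|'-delimited string into rows of stripped cell strings.

    Blank lines are skipped and the first remaining line is the header,
    matching how predict() builds `table`.
    """
    table = [[cell.strip() for cell in line.split("|")]
             for line in table_data.split("\n") if line.strip()]
    header_width = len(table[0]) if table else 0
    for row_number, row in enumerate(table[1:], start=1):
        # Extra sanity check (not in the notebook): each data row should
        # have exactly as many cells as the header.
        if len(row) != header_width:
            raise ValueError(
                f"Row {row_number} has {len(row)} cells, expected {header_width}")
    return table


sample = """
Doctor_ID|Doctor_Name|Department
12|Makan|Immunology

25|Abhishek|Immunology
"""
sample_table = parse_pipe_table(sample)
print(sample_table[0])   # ['Doctor_ID', 'Doctor_Name', 'Department']
print(sample_table[1:])  # [['12', 'Makan', 'Immunology'], ['25', 'Abhishek', 'Immunology']]
```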

# Predict

In [7]:
# Example nu-1000-0
result = predict("""
Doctor_ID|Doctor_Name|Department|opd_day|Morning_time|Evening_time
1|ABCD|Nephrology|Monday|9|5
2|ABC|Opthomology|Tuesday|9|6
3|DEF|Nephrology|Wednesday|9|6
4|GHI|Gynaecology|Thursday|9|6
5|JKL|Orthopeadics|Friday|9|6
6|MNO|Cardiology|Saturday|9|6
7|PQR|Dentistry|Sunday|9|5
8|STU|Epidemology|Monday|9|6
9|WVX|ENT|Tuesday|9|5
10|GILOY|Genetics|Wednesday|9|6
11|Rajeev|Neurology|Wednesday|10|4:30
12|Makan|Immunology|Tuesday|9|4:30
13|Arora|Paediatrics|Sunday|11|4:30
14|Piyush|Radiology|Monday|11:20|2
15|Roha|Gynaecology|Wednesday|9:20|2
16|Bohra|Dentistry|Thursday|11|2
17|Rajeev Khan|Virology|Tuesday|10|2
18|Arnab|Pharmocology|Sunday|10|2
19|Muskan|ENT|Friday|10|2
20|pamela|Epidemology|Monday|10|2
21|Rohit|Radiology|Tuesday|10|2
22|Aniket|Cardiology|Saturday|10|2
23|Darbar|Genetics|Saturday|10|2
24|Suyash|Neurology|Friday|10|2
25|Abhishek|Immunology|Wednesday|10|2
26|Yogesh|Immunology|Saturday|10|2
27|Kunal|Paediatrics|Monday|10|2
28|Vimal|Pharmocology|Friday|10|2
29|Kalyan|Virology|Tuesday|10|2
30|DSS|Nephrology|Thursday|10|2

""", ["How many doctors are there in Immunology department?", "of these, which doctor is available on Saturday?"])

is_built_with_cuda: True
is_gpu_available: False
GPUs: []
Training or predicting ...
Evaluation finished after training step 0.


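| Doctor_ID | Doctor_Name | Department | opd_day | Morning_time | Evening_time |
|---|---|---|---|---|---|
| 1 | ABCD | Nephrology | Monday | 9 | 5 |
| 2 | ABC | Opthomology | Tuesday | 9 | 6 |
| 3 | DEF | Nephrology | Wednesday | 9 | 6 |
| 4 | GHI | Gynaecology | Thursday | 9 | 6 |
| 5 | JKL | Orthopeadics | Friday | 9 | 6 |
| 6 | MNO | Cardiology | Saturday | 9 | 6 |
| 7 | PQR | Dentistry | Sunday | 9 | 5 |
| 8 | STU | Epidemology | Monday | 9 | 6 |
| 9 | WVX | ENT | Tuesday | 9 | 5 |
| 10 | GILOY | Genetics | Wednesday | 9 | 6 |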



> How many doctors are there in Immunology department?
12, 26, 25
> of these, which doctor is available on Saturday?
Yogesh
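`predict()` turns each predicted `(row, column)` coordinate into cell text with `table[row + 1][col]`; the `+ 1` skips the header, because coordinates index data rows only. The first answer above is printed as `12, 26, 25` rather than in table order, which is consistent with the parsed coordinates being an unordered collection such as a set (an assumption; check `prediction_utils.parse_coordinates` in your installed version). The sketch below sorts the coordinates so answers follow table order. It assumes the `answer_coordinates` column holds a Python-literal list of `"(row, column)"` strings; that format and both helper names are hypothetical, and in the notebook itself you could keep `prediction_utils.parse_coordinates` and only add the `sorted(...)` call.

```python
import ast


def parse_answer_coordinates(raw_coordinates):
    """Parse a string like "['(11, 0)', '(25, 0)']" into (row, col) tuples.

    Assumed format: a Python-literal list whose items are "(row, col)" strings.
    """
    return [tuple(ast.literal_eval(item))
            for item in ast.literal_eval(raw_coordinates)]


def answers_in_table_order(table, coordinates):
    """Look up cell text for each coordinate, sorted by (row, column).

    table[0] is the header, so data row r lives at table[r + 1].
    """
    return [table[row_index + 1][col_index]
            for row_index, col_index in sorted(coordinates)]


toy_table = [
    ["Doctor_ID", "Doctor_Name", "Department"],
    ["12", "Makan", "Immunology"],
    ["25", "Abhishek", "Immunology"],
    ["26", "Yogesh", "Immunology"],
]
toy_coordinates = parse_answer_coordinates("['(0, 0)', '(2, 0)', '(1, 0)']")
print(", ".join(answers_in_table_order(toy_table, toy_coordinates)))  # 12, 25, 26
```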
