In [1]:
from pprint import pprint
import pytorch_lightning as pl
import torch.nn as nn
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
import hashlib
import torch
from transformers import (
    BertModel,
    BertTokenizer,
)
from typing import List
from torch.utils.data import TensorDataset, DataLoader
from sklearn.metrics import accuracy_score
import torch.nn.functional as F
import numpy as np
from copy import deepcopy
import json
import os
from tqdm.auto import tqdm

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


# Load model

In [38]:
from train_retriever_classification_new import *

encoder_question = BertEncoder(bert_question, max_question_len_global)
encoder_paragarph = BertEncoder(bert_paragraph, max_paragraph_len_global)
ret = Retriver(encoder_question, encoder_paragarph, tokenizer)

checkpoint_callback = ModelCheckpoint(
    filepath='out/{epoch}-{val_loss:.2f}-{val_acc:.2f}',
    save_top_k=10,
    verbose=True,
    monitor='val_acc',
    mode='max'
)

early_stopping = EarlyStopping('val_acc', mode='max')

trainer = pl.Trainer(
    gpus=0,
#     distributed_backend='dp',
    val_check_interval=0.01,
    min_epochs=1,
    checkpoint_callback=checkpoint_callback,
    early_stop_callback=early_stopping,
    gradient_clip_val=0.5
)
ret_trainee = RetriverTrainer(ret)

tmp = torch.load('out/crossentropy-epoch=0-val_loss=0.75-val_acc=0.72_v0.ckpt', map_location='cpu')

ret_trainee.load_state_dict(tmp['state_dict'])

<All keys matched successfully>

# Verify BERT Validation Performance

In [4]:
import json

In [6]:
def remove_html_toks(s):
    html_toks = [
        '<P>',
        '</P>',
        '<H1>',
        '</H1>',
        '</H2>',
        '</H2>',
    ]
    for i in html_toks:
        s = s.replace(i, '')
    return s

In [36]:
ret_trainee.retriever.refresh_cache()

In [40]:
hits = []
logits = []
c = 0
with open('natq/natq_clean.json', 'r') as f:
    for l in f:
        d = json.loads(l)
        if d['num_positives'] >= 1 and d['num_negatives'] >= 2:
            if d['dataset'] == 'dev':
                q = d['question']
                txts = [remove_html_toks(i) for i in d['right_paragraphs'][:1] + d['wrong_paragraphs'][:2]]
                r = ret_trainee.retriever.predict(q, txts)
                logits.append(r[1])
                if r[0][0] == txts[0]:
                    hits.append(1)
                else:
                    hits.append(0)
                c += 1
                if c % 9 == 1:
                    print(c, sum(hits) / len(hits))

Bert Encoder forwarded 1 times
Bert Encoder forwarded 1 times
1 1.0
10 0.8


KeyboardInterrupt: 

In [8]:
logits

[array([0.7310163 , 0.59958375, 0.31147915], dtype=float32),
 array([0.33012572, 0.27348018, 0.26916775], dtype=float32),
 array([0.72824657, 0.70778316, 0.2909686 ], dtype=float32),
 array([0.7219803 , 0.7213584 , 0.44483003], dtype=float32),
 array([0.7310514 , 0.7294375 , 0.29001316], dtype=float32),
 array([0.31494823, 0.26924643, 0.2692249 ], dtype=float32),
 array([0.68954176, 0.55632776, 0.30350024], dtype=float32),
 array([0.73104435, 0.54280925, 0.2690589 ], dtype=float32),
 array([0.5894281 , 0.27045855, 0.2691992 ], dtype=float32),
 array([0.730893 , 0.7306549, 0.6863563], dtype=float32),
 array([0.7175499 , 0.35694364, 0.3417997 ], dtype=float32),
 array([0.49935776, 0.2694974 , 0.26932088], dtype=float32),
 array([0.7081304 , 0.68516076, 0.6615101 ], dtype=float32),
 array([0.7292225 , 0.61326706, 0.5391778 ], dtype=float32),
 array([0.73102885, 0.6198809 , 0.6189771 ], dtype=float32),
 array([0.7306189 , 0.46878347, 0.29780522], dtype=float32),
 array([0.73051345, 0.60892

In [9]:
sum(hits) / len(hits)

0.7647058823529411

In [10]:
q = d['question']
txts = [remove_html_toks(i) for i in d['right_paragraphs'][:1] + d['wrong_paragraphs'][:2]]

In [11]:
q

'the chemical brothers brothers gonna work it out songs'

In [12]:
txts

['<Table> <Tr> <Th> No . </Th> <Th> Title </Th> <Th> Writer ( s ) </Th> <Th> Producer ( s ) </Th> <Th> Length </Th> </Tr> <Tr> <Td> 1 . </Td> <Td> `` Brother \'s Gonna Work It Out \'\' ( performed by Willie Hutch ) </Td> <Td> Hutch </Td> <Td> Hutch </Td> <Td> 4 : 00 </Td> </Tr> <Tr> <Td> 2 . </Td> <Td> `` Not Another Drugstore \'\' ( Planet Nine Mix ) ( performed by Chemical Brothers featuring Justin Warfield ) </Td> <Td> <Ul> <Li> Tom Rowlands </Li> <Li> Ed Simons </Li> <Li> Gianni Garofalo </Li> <Li> Warfield </Li> </Ul> </Td> <Td> The Chemical Brothers </Td> <Td> 3 : 25 </Td> </Tr> <Tr> <Td> 3 . </Td> <Td> `` Block Rockin \' Beats \'\' ( The Micronauts Mix ) </Td> <Td> <Ul> <Li> Tom Rowlands </Li> <Li> Ed Simons </Li> <Li> Jesse Weaver </Li> </Ul> </Td> <Td> <Ul> <Li> The Chemical Brothers </Li> <Li> Christophe Monier </Li> <Li> George Issakidis </Li> </Ul> </Td> <Td> 3 : 26 </Td> </Tr> <Tr> <Td> 4 . </Td> <Td> `` This Ai n\'t Chicago \'\' ( performed by On The House ) </Td> <Td> Cu

In [13]:
ret_trainee.retriever.predict(q, txts)

(['<Table> <Tr> <Th> No . </Th> <Th> Title </Th> <Th> Writer ( s ) </Th> <Th> Producer ( s ) </Th> <Th> Length </Th> </Tr> <Tr> <Td> 1 . </Td> <Td> `` Brother \'s Gonna Work It Out \'\' ( performed by Willie Hutch ) </Td> <Td> Hutch </Td> <Td> Hutch </Td> <Td> 4 : 00 </Td> </Tr> <Tr> <Td> 2 . </Td> <Td> `` Not Another Drugstore \'\' ( Planet Nine Mix ) ( performed by Chemical Brothers featuring Justin Warfield ) </Td> <Td> <Ul> <Li> Tom Rowlands </Li> <Li> Ed Simons </Li> <Li> Gianni Garofalo </Li> <Li> Warfield </Li> </Ul> </Td> <Td> The Chemical Brothers </Td> <Td> 3 : 25 </Td> </Tr> <Tr> <Td> 3 . </Td> <Td> `` Block Rockin \' Beats \'\' ( The Micronauts Mix ) </Td> <Td> <Ul> <Li> Tom Rowlands </Li> <Li> Ed Simons </Li> <Li> Jesse Weaver </Li> </Ul> </Td> <Td> <Ul> <Li> The Chemical Brothers </Li> <Li> Christophe Monier </Li> <Li> George Issakidis </Li> </Ul> </Td> <Td> 3 : 26 </Td> </Tr> <Tr> <Td> 4 . </Td> <Td> `` This Ai n\'t Chicago \'\' ( performed by On The House ) </Td> <Td> C

# Verify dataset

In [14]:
import torch

In [15]:
train_tensor = torch.load('natq/natq_train.pt')
val_tensor = torch.load('natq/natq_dev.pt')

In [16]:
## Berts
model_str = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(model_str)


In [17]:
rev_vocab = {tokenizer.vocab[i]: i for i in tokenizer.vocab}

In [19]:
q_tok_ids, _, _, para_tok_ids, _, _ = train_tensor[0]

In [20]:
def toks2text(toks):
    return ' '.join([rev_vocab[i] for i in toks.numpy()])

In [21]:
toks2text(q_tok_ids)

'[CLS] what is an advantage of daily vertical migration by marine zoo ##pl ##an ##kt ##on quiz ##let [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]'

In [22]:
toks2text(para_tok_ids[0])

'[CLS] die ##l vertical migration ( d ##v ##m ) , also known as di ##urn ##al vertical migration , is a pattern of movement used by some organisms , such as cope ##pods , living in the ocean and in lakes . the migration occurs when organisms move up to the ep ##ipe ##lag ##ic zone at night and return to the me ##sop ##ela ##gic zone of the oceans or to the h ##yp ##oli ##m ##nio ##n zone of lakes during the day . the word die ##l comes from the latin dies day , and means a 24 - hour period . it is the greatest migration in the world in terms of biomass . this form of migration is not restricted to any one taxa as examples are known from crust ##ace ##ans ( cope ##pods ) , mo ##llus ##cs ( squid ) , and ray - finn ##ed fishes ( trout ) . there are various stimuli responsible for this phenomenon , the most prominent being response to changes in light intensity , though evidence suggests that biological clock s are an underlying stimulus as well . there are a number of potential reasons t

In [23]:
hits = []
logits = []
c = 0
with open('natq/natq_clean.json', 'r') as f:
    for l in f:
        d = json.loads(l)
        if d['num_positives'] >= 1 and d['num_negatives'] >= 2:
            if d['dataset'] == 'dev':
                q = d['question']
                txts = [remove_html_toks(i) for i in d['right_paragraphs'][:1] + d['wrong_paragraphs'][:2]]
                q_tok_ids, _, _, para_tok_ids, _, _ = val_tensor[c]
                print(q, '\n', toks2text(q_tok_ids))
                c += 1
                print('-'*50)
                print(txts[0], '\n', toks2text(para_tok_ids[0]))
                print('-'*50)
        if c == 1:
            break

when did computer become widespread in homes and schools 
 [CLS] when did computer become widespread in homes and schools [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]
--------------------------------------------------
 Home computers were a class of microcomputers entering the market in 1977 , and becoming common during the 1980s . They were marketed to consumers as affordable and accessible computers that , for the first time , were intended for the use of a single nontechnical user . These computers were a distinct market segment that typically cost much less than business , scientific or engineering - oriented computers of the time such as the IBM PC , and were generally less powerful in terms of memory and expandability . However , a home computer often had better graphics and sound than contemporary business computers . Their most common uses were playing video games , but they were also regularly used for 

# Verify simple examples

In [28]:
q='What is symptom of Covid?'
s1='Covid can cause fever and cough.'
s2='Covid is caused by a virus.'
s3='The world is going crazy.'

In [24]:
q= 'Do you love me?'
s1= 'I love you.'
s2= 'You hate me.'
s3= 'It aint much but it is honesdt work.'

In [29]:
ret_trainee.retriever.predict(q, [s1, s2, s3])

(['The world is going crazy.',
  'Covid can cause fever and cough.',
  'Covid is caused by a virus.'],
 array([0.6281982 , 0.35796586, 0.30579755], dtype=float32),
 array([2, 0, 1]))

# Verify changing bert encoe with vanillar ones

In [39]:
bert_q = BertModel.from_pretrained('bert-base-uncased')
bert_para = BertModel.from_pretrained('bert-base-uncased')

ret_trainee.retriever.bert_question_encoder.bert.load = bert_q

ret_trainee.retriever.bert_paragraph_encoder.bert = bert_para

In [50]:
ret_trainee.retriever.bert_question_encoder.bert.load_state_dict(bert_q.state_dict)

AttributeError: 'function' object has no attribute 'copy'

# Verify Covid results

In [41]:
from elasticsearch import Elasticsearch

In [42]:
es = Elasticsearch(
        [{"host": "es-covidfaq.dev.dialoguecorp.com", "port": 443}],
        use_ssl=True,
        verify_certs=True,
    )
if not es.ping():
    raise ValueError(
        "Connection failed, please start server at localhost:9200 (default)"
    )

In [43]:
def search_section_index(es, index, query, topk):
    res = es.search(
        {
            "query": {
                "multi_match": {"query": query, "fields": ["section", "content"],}
            },
            "size": topk,
        },
        index=index,
    )
    return res

In [45]:
secindex = "en-covid-section-index"
topk_sec = 30

In [46]:
qs = [
    'What is Covid-19?',
    'What are the symptoms of Covid 19? ',
    'How does Covid-19 spread?',
    'When should I go to the hospital?',
    'How many cases in Montreal?',
    'What should I do if I have fever?',
    'What is the incubation period for COVID-19?',
    'How can I protect myself from the covid-19?',
    'How can I make the difference between a cold and covid19?',
    'Where can I get tested?',
    'Is it true that warm kills Coronavirus?',
]

In [47]:
for q in qs:

    rs = search_section_index(es, secindex, q, topk_sec)["hits"]["hits"]
    print('question: ', q, '\n')
    txts = [' '.join(i['_source']['content']) + '\n'*5 for i in rs]
    print('Elastic Search: ')
    pprint(txts[:3])
    
    rs_rerank = ret_trainee.retriever.predict(q, txts)
    print('=' * 50, '\n')
    print('BERT: ')
    pprint(rs_rerank[0][:3])
    pprint(rs_rerank[1])
    print('-' * 100)

question:  What is Covid-19? 

Elastic Search: 
['COVID-19 usually infects the nose, throat and lungs. In most cases, it is '
 'spread by: close contact with an infected person when the person coughs or '
 'sneezes;touching infected surfaces with your hands and then touching your '
 'mouth, nose or eyes.\n'
 '\n'
 '\n'
 '\n'
 '\n',
 'There is no specific treatment or vaccine for COVID-19 for the moment. Most '
 'people with the virus will recover on their own\xa0by remaining at home '
 'without needing to go to the hospital. \n'
 '\n'
 '\n'
 '\n'
 '\n',
 'The main symptoms of COVID-19 are as follows: FeverCoughDifficulty breathing '
 'The symptoms can be mild (similar to a cold) or more severe (such as those '
 'associated with pneumonia and respiratory or kidney failure).\n'
 '\n'
 '\n'
 '\n'
 '\n']

BERT: 
['The main symptoms of COVID-19 are as follows: FeverCoughDifficulty breathing '
 'The symptoms can be mild (similar to a cold) or more severe (such as those '
 'associated with pn


BERT: 
['The main symptoms of COVID-19 are as follows: FeverCoughDifficulty breathing '
 'The symptoms can be mild (similar to a cold) or more severe (such as those '
 'associated with pneumonia and respiratory or kidney failure).\n'
 '\n'
 '\n'
 '\n'
 '\n',
 'COVID-19 is a disease caused by a coronavirus, a highly contagious virus '
 'that affects the respiratory tract. It is transmitted from one person to '
 'another. A pandemic occurs when a new virus spreads throughout the world. '
 'Since humans are not protected against the new virus, a greater number of '
 'people become sick.\n'
 '\n'
 '\n'
 '\n'
 '\n',
 'Visits Starting March\xa014,\xa02020, non-essential\xa0visits to hospitals, '
 'residential and long-term car centres, intermediate resources, targeted '
 'family-type resources\xa0and private senior’s residence will be prohibited '
 'to protect the most vulnerable individuals and workers in the health and '
 'social services network. Lockdown Since March 23, 2020, a lockdown


BERT: 
['Starting March 13, 2020, non-essential\xa0visits to hospitals, residential '
 'and long-term care centres, intermediate resources, targeted family-type '
 'resources and private seniors’ homes will be prohibited to protect the most '
 'vulnerable individuals and workers in the health and social services '
 'network. In the case of births, the other parent, significant other or '
 'natural caregiver identified for this purpose is not deemed to be a visitor '
 'and can accompany the mother. In cases where the end of life is imminent, '
 'the constant presence of one or two people who are of significant importance '
 'to the person is not deemed to be a visit. \n'
 '\n'
 '\n'
 '\n'
 '\n',
 'Contrary to fear, which is a response to a well-defined and very real '
 'threat, anxiety is a response to a vague or unknown threat. Anxiety '
 'manifests itself when we believe that a dangerous or unfortunate event may '
 'take place and are expecting it. Everyone experiences anxiety at the


BERT: 
['The main symptoms of COVID-19 are as follows: FeverCoughDifficulty breathing '
 'The symptoms can be mild (similar to a cold) or more severe (such as those '
 'associated with pneumonia and respiratory or kidney failure).\n'
 '\n'
 '\n'
 '\n'
 '\n',
 'The procedures are outlined in the <a '
 'href="/en/family-and-support-for-individuals/emergency-daycare-services/">Emergency '
 'Childcare Services section</a>\n'
 '\n'
 '\n'
 '\n'
 '\n',
 'COVID-19 is a disease caused by a coronavirus, a highly contagious virus '
 'that affects the respiratory tract. It is transmitted from one person to '
 'another. A pandemic occurs when a new virus spreads throughout the world. '
 'Since humans are not protected against the new virus, a greater number of '
 'people become sick.\n'
 '\n'
 '\n'
 '\n'
 '\n']
array([0.72472215, 0.706569  , 0.6897979 , 0.6742099 , 0.609067  ,
       0.60280806, 0.5589649 , 0.5499372 , 0.48876572, 0.3794523 ,
       0.3707552 , 0.36034146, 0.3509054 , 0.33732924, 


BERT: 
['If your business conducts an activity that is not listed above, but you '
 'consider it essential, you can verify the designation as an essential '
 'business. If you have any questions, fill out this online form and an '
 'information officer of the Gouvernement du Québec will contact you '
 'shortly.\xa0\xa0 Complete the form\xa0\n'
 '\n'
 '\n'
 '\n'
 '\n',
 'Fever is one of the body’s defence mechanisms that help fight infection. '
 'Acetaminophen is recommended to reduce fever and make you more comfortable '
 'unless your health professional advises against it or you are allergic. '
 'Fever is defined as follows: • Children: 38 °C (100.4 °F) and more (rectal) '
 '• Adults: 38 °C (100.4 °F) and more (oral) • Seniors: 37.8 °C (98.6 °F) and '
 'more (oral) OR • 1.1 °C higher than normal\n'
 '\n'
 '\n'
 '\n'
 '\n',
 'The coronavirus COVID-19 is very easily transmitted by tiny droplets that '
 'are expelled into the air when an infected person coughs or sneezes. If the '
 'per