In [17]:
import pickle
import random
from tqdm import tqdm
import copy
from nltk.translate.bleu_score import sentence_bleu
import matplotlib.pyplot as plt
from sentence_splitter import SentenceSplitter, split_text_into_sentences
import re
import numpy as np
from chat_scripts.deberta_qa import DebertaQA
import pandas as pd
from random import shuffle


In [2]:
import torch
import json
from transformers import AutoModelForQuestionAnswering, AutoTokenizer
#from datasets import load_dataset, load_metric
import pickle

## Data construction

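Gold answers are constructed with `sentence_bleu`: a window the length of the bot's answer slides over the source text, each window is scored against the answer, and the sentences covered by the best-scoring windows become the gold answer span(s).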

In [3]:
'''
Class for preproc dialogue data generated by OpenAI API
'''
class MyCorpus():
    '''
    Pre-processing of dialogue data generated by the OpenAI API.

    Holds one source text at a time (see `reinit`), split into sentences and
    whitespace tokens, and finds which sentences of that text "ground" a given
    answer by sliding a BLEU-scored window over the tokens (see `find_grounding`).

    Hyper-params:
        min_tresh: windows with raw BLEU <= min_tresh are discarded.
        norm_tresh: windows whose BLEU, divided by the best BLEU for this answer,
            is below norm_tresh are discarded.
        deberta_max_tokens, deberta_min_tokens_diff, deberta_sents_around:
            forwarded unchanged to `self.deberta_qa.extract_grounding`.
    '''
    def __init__(self, min_tresh=0.05, norm_tresh=0.75, deberta_max_tokens=500, deberta_min_tokens_diff=200, deberta_sents_around=0): # constructor
        # hyper-params
        self.windows_shift = 2  # step (in tokens) between consecutive BLEU windows
        self.min_tresh = min_tresh
        self.norm_tresh = norm_tresh
        self.deberta_max_tokens = deberta_max_tokens
        self.deberta_min_tokens_diff = deberta_min_tokens_diff
        self.deberta_sents_around = deberta_sents_around

        self.splitter = SentenceSplitter(language='en')
        # The QA model is not created here (the original construction was commented out);
        # assign `self.deberta_qa = DebertaQA(...)` before calling deberta_extract_grounding.
        self.deberta_qa = None

    def reinit(self, text):
        '''
        Load a new source text and reset all per-text state.

        `self.text` is `text` with every whitespace run collapsed to a single space;
        all character offsets stored in `self.sents` refer to `self.text`.
        '''
        self.text = re.sub(r"\s+", " ", text)
        # split the raw text: the splitter sees the original line breaks
        self.sents_ = self.splitter.split(text)
        self.sents = []   # (sent_id, sent, sent_start, sent_end)
        self.words = []   # (word_id, word, sent_id)
        self.bleu_windows = []
        self.grounding_candidates = []
        self.debug_grounding_sents = []
        self.groundings = []
        self.groundings_detailed = []
        self.bleu_scores_debug = []
        self.deberta_grounding = None
        self.deberta_span = None

        self.groundings_none = []  # answers with no grounding: 'yes, this is correct', bleu=0
        self.groundings_not_found = []  # answers with bleu below tresholds

        self._process_corpus()

    def _process_corpus(self):
        '''
        Fill `self.sents` and `self.words` from `self.sents_`.

        Each sentence is located in `self.text` starting from the end of the
        previous one. If a sentence cannot be located, a diagnostic line is
        printed and -1 is returned, leaving the sentences processed so far.
        '''
        offset = 0
        word_id_max = 0

        for sent_id, raw_sent in enumerate(self.sents_):
            # self.text has collapsed whitespace while the sentence comes from the raw
            # text; normalise the sentence the same way or a sentence containing a tab
            # or a double space would never be found.
            sent = re.sub(r"\s+", " ", raw_sent)
            sent_start = self.text.find(sent, offset)

            if sent_start == -1:
                print(f'sent_id={sent_id}, sent=|{sent}|, offset={offset}')
                return -1

            sent_end = sent_start + len(sent)
            offset = sent_end

            self.sents.append((sent_id, sent, sent_start, sent_end))

            sent_split = sent.split()
            self.words.extend((wid + word_id_max, w, sent_id) for wid, w in enumerate(sent_split))
            word_id_max += len(sent_split)

    def find_grounding(self, answer, answer_id):
        '''
        Find the sentences of the current text that best match `answer`.

        A window of len(answer.split()) tokens slides over `self.words` with step
        `self.windows_shift`; every window is scored against the answer with
        sentence_bleu (uni- and bi-gram weights 0.5 / 0.5).

        Results are stored on the instance (nothing is returned):
            self.bleu_windows: one record per window,
                [win_start, win_end, score, score_filtered, score_norm, answer, window_text]
                (win_end is exclusive).
            self.groundings: lists of consecutive sentence ids.
            self.groundings_detailed: the same, as lists of `self.sents` tuples.
            self.groundings_none: gets `answer_id` when the best score is exactly 0.
            self.groundings_not_found: gets `answer_id` when no window passes the thresholds.
        '''
        # Per-answer state is reset here so that repeated calls do not mix results.
        self.bleu_windows = []
        self.grounding_candidates = []
        self.groundings = []
        self.groundings_detailed = []

        answer_tokens = answer.split()
        window_len = len(answer_tokens)
        window_num_sents = len(self.splitter.split(answer))
        max_score = -1

        # "+ 1" so that the window ending on the very last token is scored as well
        for win_start in range(0, len(self.words) - window_len + 1, self.windows_shift):
            win_end = win_start + window_len
            win_tokens = [w[1] for w in self.words[win_start:win_end]]
            bleu_score = sentence_bleu([win_tokens], answer_tokens, weights=(0.5, 0.5, 0., 0.))
            bleu_score_filtered = -1.0 if bleu_score <= self.min_tresh else bleu_score # apply min theshold

            if bleu_score > max_score:
                max_score = bleu_score

            self.bleu_windows.append([win_start, win_end, bleu_score, bleu_score_filtered, None, answer, ' '.join(win_tokens)])

        if max_score == 0:
            self.groundings_none.append(answer_id)
            return

        # normalize scores by the best window (filtered-out windows stay at -1)
        for w in self.bleu_windows:
            w[4] = w[3]/max_score if w[3]>0 else -1.0

        # apply max norm treshold
        self.grounding_candidates = [w for w in self.bleu_windows if w[4] >= self.norm_tresh]

        if not self.grounding_candidates:
            self.groundings_not_found.append(answer_id)
            return

        # Convert grounding candidates to [first_sent_id, last_sent_id, norm_score]
        grounding_sents_ = []
        for cand in self.grounding_candidates:
            first_sent = self.words[cand[0]][2]  # each word is (wid, w, sentid)
            # win_end is exclusive: the last token of the window is win_end - 1
            last_sent = self.words[cand[1] - 1][2]
            grounding_sents_.append([first_sent, last_sent, cand[4]])

        grounding_sents_.sort(key=lambda x: x[2], reverse=True)
        self.debug_grounding_sents = grounding_sents_

        # Take candidates best-first until more sentences are collected than the
        # answer itself contains (number of grounding sents ~ number of answer sents).
        gss = set()
        for first_sent, last_sent, _ in grounding_sents_:
            gss.update(range(first_sent, last_sent + 1))  # both ends included

            if len(gss) > window_num_sents:
                break

        # set of continuous groundings
        self.groundings = self.split_to_intervals(sorted(gss))

        for g in self.groundings:
            self.groundings_detailed.append([self.sents[sid] for sid in g])

    def deberta_extract_grounding(self, question):
        '''
        Extract a grounding for `question` from the current text with `self.deberta_qa`
        and store the result in `self.deberta_grounding` / `self.deberta_span`.

        Raises AttributeError if no QA model has been assigned to `self.deberta_qa`.
        '''
        if getattr(self, 'deberta_qa', None) is None:
            raise AttributeError(
                "MyCorpus.deberta_qa is not set; assign a DebertaQA instance before "
                "calling deberta_extract_grounding"
            )

        grounding, span = self.deberta_qa.extract_grounding(question=question,
                                                            context=self.text,
                                                            max_tokens=self.deberta_max_tokens,
                                                            min_tokens_diff=self.deberta_min_tokens_diff,
                                                            return_response_span=True,
                                                            sents_around=self.deberta_sents_around)

        self.deberta_grounding = grounding
        self.deberta_span = span

    @staticmethod
    def split_to_intervals(s):
        '''
        Utility function for splitting a sorted sequence of integers (grounding sents ids)
        into runs of consecutive numbers, e.g. [1, 2, 3, 7, 8, 10] -> [[1, 2, 3], [7, 8], [10]].

        A new run starts whenever the gap to the previous id is larger than 1.
        An empty input yields an empty list.
        '''
        intervals = []
        for sid in s:
            if intervals and sid - intervals[-1][-1] <= 1:
                intervals[-1].append(sid)
            else:
                intervals.append([sid])

        return intervals


In [19]:
def clean_dataset(data):
    '''
    Keep only dialogues whose turns strictly alternate person -> bot.

    For Davinci only!

    A dialogue (item['parsed_dialogue']['turns']) is dropped when it has an odd
    number of turns, or when any even-indexed turn is not spoken by 'person' or
    any odd-indexed turn is not spoken by 'bot' (an empty dialogue is dropped too).

    Prints the total / bad / clean counts and returns the list of kept items
    in their original order. `data` itself is not modified.
    '''
    data_clean = []

    # Single pass; avoids the quadratic `index not in list_of_bad_indices` filter.
    for item in data:
        turns = item['parsed_dialogue']['turns']
        if len(turns) % 2 == 1:
            continue

        person_speakers = {t['speaker'] for t in turns[::2]}
        bot_speakers = {t['speaker'] for t in turns[1::2]}

        if person_speakers == {'person'} and bot_speakers == {'bot'}:
            data_clean.append(item)

    n_bad = len(data) - len(data_clean)
    print(f'Total: {len(data)}, bad: {n_bad}, clean: {len(data_clean)}')

    return data_clean



def generate_squad_dataset(data, params):
    '''
    Generates GT for bot answers and returns a list of SQuAD-style dicts,
    one per (person question, bot answer) pair:

        template = {
            'id': '<segment id>_<person turn idx>_<bot turn idx>',
            'title': segment title ('' when the title is None),
            'context': source text with whitespace collapsed (MyCorpus.text),
            'question': text of the person turn,
            'chat_gpt_answer': text of the bot turn,
            'answers': {
                'text': [],          # one entry per grounding: its sentences joined by ' '
                'answer_start': []   # char offset of the grounding's first sentence in 'context'
            }
        }

    data: items with keys 'text', 'meta_segments' (first element provides 'id' and
        'title') and 'parsed_dialogue' ({'turns': [{'text': ...}, ...]}).
    params: keyword arguments for MyCorpus.

    Turns are consumed in pairs; a trailing turn without a following answer is skipped.
    '''

    mc = MyCorpus(**params)
    res = []

    # wrap the data (not the enumerate) so the progress bar knows the total
    for item_id, item in enumerate(tqdm(data)):
        # Sentence/word splitting depends only on the text: do it once per item.
        mc.reinit(item['text'])

        segment = item['meta_segments'][0]
        d_id = str(segment['id'])
        raw_title = segment['title']
        d_title = '' if raw_title is None else str(raw_title)  # avoid the literal string 'None'
        d_text = mc.text

        ts = item['parsed_dialogue']['turns']

        # len(ts) - 1: never index past the end when the number of turns is odd
        for person_turn in range(0, len(ts) - 1, 2):
            # find_grounding appends to these lists; clear what the previous turn left
            mc.groundings = []
            mc.groundings_detailed = []

            qa_id = f'{person_turn}_{person_turn+1}'
            q = ts[person_turn]['text']
            a = ts[person_turn+1]['text']

            mc.find_grounding(a, f'{item_id}_{qa_id}')  # by bleu

            # 'answers' - dict from grounding
            answers_dict = {'text': [], 'answer_start': []}
            for g in mc.groundings_detailed:
                answers_dict['text'].append(' '.join([s[1] for s in g]))  # joined sents
                answers_dict['answer_start'].append(g[0][2])  # start of 1st sent

            res.append({
                'id': f'{d_id}_{qa_id}',
                'title': d_title,
                'context': d_text,
                'question': q,  # person
                'chat_gpt_answer': a,  # GT
                'answers': answers_dict,
            })

    return res
    

In [35]:
# Load the post-processed dialogue dumps (one pickle per generator model).
# NOTE: pickle.load executes arbitrary code from the file - only load trusted files.
with open('../_common/papers_segmented_data/davinci_dialogues_full_postproc.pkl', 'rb') as f:
    davinchi_data_gross = pickle.load(f)

with open('../_common/papers_segmented_data/chatgpt_dialogues_full_postproc_upd.pkl', 'rb') as f:
    chatgpt_data = pickle.load(f)    

In [36]:
davinchi_data = clean_dataset(davinchi_data_gross)

Total: 3588, bad: 771, clean: 2817


In [37]:
# from grid search
# min_tresh / norm_tresh: BLEU thresholds used by MyCorpus.find_grounding
#   (floor on the raw score, and cut-off on the score normalised by the best window).
# deberta_*: forwarded unchanged to deberta_qa.extract_grounding by
#   MyCorpus.deberta_extract_grounding.
best_params = {'min_tresh': 0.15,
               'norm_tresh': 0.8,
               'deberta_max_tokens': 500,
               'deberta_min_tokens_diff': 200,
               'deberta_sents_around': 1}

# number of dialogues held out for validation (taken from the end of the shuffled list)
valid_size = 400

In [38]:
# Merge both sources and split off the last `valid_size` dialogues for validation.
# A dedicated, seeded RNG makes the split reproducible after a kernel restart.
split_rng = random.Random(42)

dataset = davinchi_data + chatgpt_data
split_rng.shuffle(dataset)

# Explicit split index: `dataset[:-valid_size]` would give an empty train set for valid_size == 0.
split_at = max(0, len(dataset) - valid_size)
train, valid = dataset[:split_at], dataset[split_at:]
len(train), len(valid)

(11204, 400)

In [39]:
train_squad = generate_squad_dataset(train, best_params)

The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 2-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
11204it [35:35,  5.25it/s]


In [40]:
with open('../_common/datasets/deberta_retrain/squad_format_train_upd.pkl', 'wb') as f:
    pickle.dump(train_squad, f)

In [41]:
valid_squad = generate_squad_dataset(valid, best_params)

with open('../_common/datasets/deberta_retrain/squad_format_valid_upd.pkl', 'wb') as f:
    pickle.dump(valid_squad, f)

400it [01:19,  5.03it/s]


In [718]:
# The Hugging Face Dataset loader expects JSON lines: one JSON object per line.
import json

for jsonl_path, squad_entries in (
    ('datasets/deberta_retrain/squad_format_train.json', train_squad),
    ('datasets/deberta_retrain/squad_format_valid.json', valid_squad),
):
    with open(jsonl_path, 'w') as outfile:
        outfile.writelines(json.dumps(entry) + '\n' for entry in squad_entries)



### Train with hard negatives

In [5]:
import json
import pickle

In [6]:
# Load the SQuAD-format train/valid sets that include hard negatives.
# NOTE(review): these `*_withhn_filt_upd.pkl` files are not written anywhere in the
# visible part of this notebook - presumably produced by another notebook; confirm.
# NOTE: pickle.load executes arbitrary code from the file - only load trusted files.
with open('/home/jovyan/chatbot/_common/datasets/deberta_retrain/squad_format_train_withhn_filt_upd.pkl', 'rb') as f:
    train_with_hard = pickle.load(f)

with open('/home/jovyan/chatbot/_common/datasets/deberta_retrain/squad_format_valid_withhn_filt_upd.pkl', 'rb') as f:
    valid_with_hard = pickle.load(f)    

In [7]:
len(train_with_hard), len(valid_with_hard)

(57306, 2039)

In [6]:
# Fix None in titles and key error: answers_start -> answer_start
def fix_entry(entry):
    '''
    Normalise one SQuAD-format entry in place and return it.

    - A title that is not exactly a `str` (e.g. None) is replaced by ''.
    - A misspelled 'answers_start' key inside entry['answers'] is renamed to
      'answer_start'.
    '''
    if type(entry['title']) != str:
        entry['title'] = ''

    answers = entry['answers']
    if 'answers_start' in answers:
        answers['answer_start'] = answers.pop('answers_start')

    return entry
    

# Write the fixed entries as JSON lines (one JSON object per line), train first, then valid.
for jsonl_path, hard_entries in (
    ('/home/jovyan/chatbot/_common/datasets/deberta_retrain/squad_format_train_withhn_filt_upd.json', train_with_hard),
    ('/home/jovyan/chatbot/_common/datasets/deberta_retrain/squad_format_valid_withhn_filt_upd.json', valid_with_hard),
):
    with open(jsonl_path, 'w') as outfile:
        outfile.writelines(json.dumps(fix_entry(entry)) + '\n' for entry in hard_entries)


# Deberta train
https://github.com/huggingface/transformers/tree/v4.26.0/examples/pytorch/question-answering

model: deepset/deberta-v3-base-squad2

Note that if your dataset contains samples with no possible answers (like SQuAD version 2), you need to pass along the flag --version_2_with_negative.

In [34]:
# Fine-tune deepset/deberta-v3-base-squad2 with the local deberta_qa/run_qa.py script
# (presumably a copy of the Hugging Face question-answering example linked above).
# NOTE(review): the train/validation files end in `_upd_use.json`, while the export cell in
# this notebook writes `_upd.json` - confirm where the `_use` files are produced.
# NOTE: the captured log reports WANDB_DISABLED as deprecated in favour of `--report_to none`.
!WANDB_DISABLED=True CUDA_VISIBLE_DEVICES=0 /home/user/conda/envs/deberta_retrain/bin/python deberta_qa/run_qa.py \
  --model_name_or_path deepset/deberta-v3-base-squad2 \
  --version_2_with_negative \
  --train_file '/home/jovyan/chatbot/_common/datasets/deberta_retrain/squad_format_train_withhn_filt_upd_use.json' \
  --do_train \
  --validation_file '/home/jovyan/chatbot/_common/datasets/deberta_retrain/squad_format_valid_withhn_filt_upd_use.json' \
  --do_eval \
  --per_device_train_batch_size=4 \
  --per_device_eval_batch_size=4 \
  --gradient_accumulation_steps=2 \
  --learning_rate 2e-5 \
  --num_train_epochs 4 \
  --save_steps 20000 \
  --doc_stride 128 \
  --output_dir="/home/jovyan/chatbot/_common/checkpoint/deberta_retrain_withhn_filt_upd_use" \
  --overwrite_output_dir \
  --max_seq_length 384 

/home/user/conda/envs/deberta_retrain
/home/user/conda/envs/deberta_retrain
Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
04/24/2023 18:35:12 - INFO - __main__ - Training/evaluation parameters TrainingArguments(
_n_gpu=1,
adafactor=False,
adam_beta1=0.9,
adam_beta2=0.999,
adam_epsilon=1e-08,
auto_find_batch_size=False,
bf16=False,
bf16_full_eval=False,
data_seed=None,
dataloader_drop_last=False,
dataloader_num_workers=0,
dataloader_pin_memory=True,
ddp_bucket_cap_mb=None,
ddp_find_unused_parameters=None,
ddp_timeout=1800,
debug=[],
deepspeed=None,
disable_tqdm=False,
do_eval=True,
do_predict=False,
do_train=True,
eval_accumulation_steps=None,
eval_delay=0,
eval_steps=None,
evaluation_strategy=no,
fp16=False,
fp16_backend=auto,
fp16_full_eval=False,
fp16_opt_level=O1,
fsdp=[],
fsdp_min_num_params=0,
fsdp_transformer_layer_cls_to_wrap=No

Running tokenizer on train dataset:   0%|      | 0/57306 [00:00<?, ? examples/s]04/24/2023 18:35:44 - INFO - datasets.arrow_dataset - Caching processed dataset at /home/jovyan/.cache/huggingface/datasets/json/default-9cc1d0e13347b8e4/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51/cache-e1a2c8f4247762c6.arrow
Running tokenizer on validation dataset:   0%|  | 0/2039 [00:00<?, ? examples/s]04/24/2023 18:36:42 - INFO - datasets.arrow_dataset - Caching processed dataset at /home/jovyan/.cache/huggingface/datasets/json/default-9cc1d0e13347b8e4/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51/cache-05af7f3dbf197559.arrow
[INFO|trainer.py:1650] 2023-04-24 18:37:09,964 >> ***** Running training *****
[INFO|trainer.py:1651] 2023-04-24 18:37:09,964 >>   Num examples = 164635
[INFO|trainer.py:1652] 2023-04-24 18:37:09,964 >>   Num Epochs = 4
[INFO|trainer.py:1653] 2023-04-24 18:37:09,964 >>   Instantaneous batch size per device = 4
[INFO|trainer.py:165

{'loss': 0.3676, 'learning_rate': 2.62792166771952e-06, 'epoch': 3.47}          
{'loss': 0.3619, 'learning_rate': 2.5064386024588173e-06, 'epoch': 3.5}         
{'loss': 0.3268, 'learning_rate': 2.3849555371981147e-06, 'epoch': 3.52}        
{'loss': 0.3241, 'learning_rate': 2.263472471937412e-06, 'epoch': 3.55}         
{'loss': 0.3289, 'learning_rate': 2.1419894066767093e-06, 'epoch': 3.57}        
{'loss': 0.3571, 'learning_rate': 2.0205063414160067e-06, 'epoch': 3.6}         
{'loss': 0.325, 'learning_rate': 1.899023276155304e-06, 'epoch': 3.62}          
{'loss': 0.3497, 'learning_rate': 1.7775402108946014e-06, 'epoch': 3.64}        
{'loss': 0.3434, 'learning_rate': 1.6560571456338987e-06, 'epoch': 3.67}        
{'loss': 0.3315, 'learning_rate': 1.5345740803731962e-06, 'epoch': 3.69}        
{'loss': 0.3632, 'learning_rate': 1.4130910151124936e-06, 'epoch': 3.72}        
{'loss': 0.3459, 'learning_rate': 1.291607949851791e-06, 'epoch': 3.74}         
{'loss': 0.3509, 'learning_r

## Manual eval

In [35]:
import pickle
import random
from tqdm import tqdm
import copy
from nltk.translate.bleu_score import sentence_bleu
import matplotlib.pyplot as plt
#from sentence_splitter import SentenceSplitter, split_text_into_sentences
import re
import numpy as np
from chat_scripts.deberta_qa import DebertaQA
import pandas as pd
from random import shuffle


  from .autonotebook import tqdm as notebook_tqdm


In [111]:
deberta_qa = DebertaQA(model_name='deepset/deberta-v3-base-squad2')

In [36]:
deberta_qa = DebertaQA(model_name='/home/jovyan/chatbot/_common/checkpoint/deberta_retrain_withhn_filt_upd_use/')

In [49]:
context = '''As shown in Table 2, our model outperforms baseline models and other lexicon-based models on four Chinese NER datasets. Our model outperforms TENER (Yan et al., 2019) by 1.72 in average F1 score. For lattice LSTM, our model has an average F1 improvement of 1.51 over it. When using another lexicon (Li et al., 2018), our model also outperforms CGN by 0.73 in average F1 score. Maybe due to the characteristic of Transformer, the improvement of FLAT over other lexicon-based models on small datasets is not so significant like that on large datasets. To verify the computation efficiency of our model, we compare the inference-speed of different lexicon-based models on Ontonotes. The result is shown in Figure 3. GNN-based models outperform lattice LSTM and LR-CNN. But the RNN encoder of GNN-based models also degrades their speed. Because our model has no recurrent module and can fully leverage parallel computation of GPU, it outperforms other methods in running efficiency. In terms of leveraging batch-parallelism, the speedup ratio brought by batch-parallelism is 4.97 for FLAT, 2.1 for lattice LSTM, when batch size = 16. Due to the simplicity of our model, it can benefit from batch-parallelism more significantly.'''

In [44]:
context = '''The flat-lattice structure consists of spans with different lengths. To encode the interactions among spans, we propose the relative position encoding of spans. For two spans x i and x j in the lattice, there are three kinds of relations between them: intersection, inclusion and separation, determined by their heads and tails. Instead of directly encoding these three kinds of relations, we use a dense vector to model their relations. It is calculated by continuous transformation of the head and tail information. Thus, we think it can not only represent the relation between two tokens, but also indicate more detailed information, such as the distance between a character and a word.'''

In [52]:
question = 'What is the speedup ratio brought by batch-parallelism?'

In [53]:
deberta_qa.predict(question, context, max_tokens=384, min_tokens_diff=64)

{'score': 0.9848097562789917,
 'start_pos': 978,
 'end_pos': 1129,
 'text': ' In terms of leveraging batch-parallelism, the speedup ratio brought by batch-parallelism is 4.97 for FLAT, 2.1 for lattice LSTM, when batch size = 16.'}