In [7]:
# IMPORTANT! ADJUST PARAMETERS ACCORDING TO MODEL

# Maximum tokenized length of the "question + context" input; used as `max_length`
# when the inputs are tokenized.
_max_input = 1024
# byte models : 1024, word models : 512

# Maximum tokenized length of the answer; used both when tokenizing the targets and as
# `max_length` for generation.
_max_output = 128
# byte models : 128, word models : 16

# Dataset split to evaluate.
splt = "validation"
# "validation" or "test"

# ASR system whose transcripts are evaluated; interpolated into the dataset name.
asr_system = "wav2vec2-large-960h-lv60-self"
# "wav2vec2-large-960h-lv60-self" or "wav2vec2-large-10min-lv60-self"

In [2]:
def add_eos_to_examples(example):
    """Add the model prompt and target strings to a single dataset example.

    The prompt is built from the lower-cased ASR transcripts of the question and
    the context; the target is the first reference answer text. Both strings end
    with an explicit ``</s>`` marker.

    Args:
        example: mapping with ``question_asr``, ``context_asr`` and
            ``answers['text']`` entries.

    Returns:
        The same ``example`` object, updated in place with ``input_text`` and
        ``target_text``.
    """
    question = example['question_asr'].lower()
    context = example['context_asr'].lower()
    example['input_text'] = 'question: %s  context: %s </s>' % (question, context)

    # Only the first reference answer is used as the target.
    first_answer = example['answers']['text'][0]
    example['target_text'] = '%s </s>' % first_answer

    return example

# tokenize the examples
def convert_to_features(example_batch):
    """Tokenize a batch of prompt/target strings into fixed-length features.

    Inputs are padded/truncated to ``_max_input`` tokens and targets to
    ``_max_output`` tokens, using the module-level ``tokenizer``.

    Args:
        example_batch: mapping with ``input_text`` and ``target_text`` lists.

    Returns:
        dict with ``input_ids``, ``attention_mask``, ``target_ids`` and
        ``target_attention_mask``.
    """
    def _encode(texts, limit):
        # Always pad to the limit so every row has the same length.
        return tokenizer.batch_encode_plus(
            texts, padding='max_length', truncation=True, max_length=limit
        )

    source = _encode(example_batch['input_text'], _max_input)
    target = _encode(example_batch['target_text'], _max_output)

    return {
        'input_ids': source['input_ids'],
        'attention_mask': source['attention_mask'],
        'target_ids': target['input_ids'],
        'target_attention_mask': target['attention_mask'],
    }

In [3]:
## SQuAD evaluation script. Modified slightly for this notebook

from __future__ import print_function
from collections import Counter
import string
import re
import argparse
import json
import sys


def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace.

    The steps run in this order: lower-case, strip every character in
    ``string.punctuation``, replace the standalone words a/an/the with a space,
    then collapse all whitespace runs to single spaces.
    """
    lowered = s.lower()
    no_punctuation = lowered.translate(str.maketrans('', '', string.punctuation))
    no_articles = re.sub(r'\b(a|an|the)\b', ' ', no_punctuation)
    return ' '.join(no_articles.split())


def f1_score(prediction, ground_truth):
    """Token-level F1 between a predicted and a reference answer.

    Both strings are normalized with ``normalize_answer`` and split on
    whitespace; the overlap is the multiset intersection of the two token lists.

    Args:
        prediction: predicted answer text.
        ground_truth: reference answer text.

    Returns:
        0 when the answers share no token, otherwise the harmonic mean of
        token precision and recall (a float in (0, 1]).
    """
    prediction_tokens = normalize_answer(prediction).split()
    ground_truth_tokens = normalize_answer(ground_truth).split()

    # Counter intersection keeps the minimum count per token, so the overlap can
    # never exceed either token list's length (precision/recall are <= 1).
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0

    precision = 1.0 * num_same / len(prediction_tokens)
    recall = 1.0 * num_same / len(ground_truth_tokens)
    return (2 * precision * recall) / (precision + recall)


def exact_match_score(prediction, ground_truth):
    """Return 1 if the two answers are identical after normalization, else 0.

    Args:
        prediction: predicted answer text.
        ground_truth: reference answer text.

    Returns:
        int: 1 on a normalized exact match, 0 otherwise.
    """
    return int(normalize_answer(prediction) == normalize_answer(ground_truth))


def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    """Score ``prediction`` against every reference and keep the best score.

    Args:
        metric_fn: callable ``(prediction, ground_truth) -> score``.
        prediction: predicted answer.
        ground_truths: iterable of reference answers.

    Returns:
        The maximum score over all references.

    Raises:
        ValueError: if ``ground_truths`` is empty.
    """
    return max([metric_fn(prediction, reference) for reference in ground_truths])


def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or 0.0 when the denominator is zero."""
    return numerator / denominator if denominator else 0.0


def _overlap_and_union(ref_span, start, end):
    """Return the (overlap, union) durations of ``ref_span`` and ``[start, end]``.

    The overlap is clamped at zero for disjoint intervals; the union is the
    length of the smallest interval covering both.
    """
    overlap = max(min(ref_span[1], end) - max(ref_span[0], start), 0)
    union = max(ref_span[1], end) - min(ref_span[0], start)
    return overlap, union


def _best_time_scores(refs, preds):
    """Return ``(frame_f1, aos)`` for one question from its time spans.

    ``refs`` holds ``[start, end]`` reference spans and ``preds`` holds
    ``[prev_start, ans_start, ans_end, next_end]`` predicted spans. AOS is
    overlap / union of a reference span and ``[ans_start, ans_end]``; the
    maximum over all examined pairs is returned. The frame F1 is the harmonic
    mean of the maximum precision and the maximum recall seen.
    NOTE(review): the two maxima may come from different (ref, pred) pairs.
    """
    max_aos = max_precision = max_recall = 0
    for ref_span in refs:
        for prev_start, ans_start, ans_end, next_end in preds:
            overlap, union = _overlap_and_union(ref_span, ans_start, ans_end)
            aos = _safe_ratio(overlap, union)
            max_aos = max(aos, max_aos)

            # Zero-length predicted or reference spans score 0 instead of
            # raising ZeroDivisionError.
            precision = _safe_ratio(overlap, ans_end - ans_start)
            recall = _safe_ratio(overlap, ref_span[1] - ref_span[0])
            max_precision = max(max_precision, precision)
            max_recall = max(max_recall, recall)

            # Scores of the span widened to the right / to the left boundary.
            shift_right_aos = _safe_ratio(
                *_overlap_and_union(ref_span, ans_start, next_end))
            shift_left_aos = _safe_ratio(
                *_overlap_and_union(ref_span, prev_start, ans_end))

            # Stop scanning further predicted spans for this reference once the
            # current span beats both of its shifted variants.
            if aos > shift_right_aos and aos > shift_left_aos:
                break

    frame_f1 = _safe_ratio(2 * max_precision * max_recall,
                           max_precision + max_recall)
    return frame_f1, max_aos


def evaluate(gold_answers, predictions,references_time = None,predictions_time = None):
    """Compute text (EM/F1) and time-span (FF1/AOS) scores as percentages.

    Args:
        gold_answers: per-question lists of reference answer texts.
        predictions: per-question predicted answer texts.
        references_time: optional per-question lists of ``[start, end]``
            reference spans.
        predictions_time: optional per-question lists of
            ``[prev_start, ans_start, ans_end, next_end]`` predicted spans.

    Returns:
        dict with ``exact_match``, ``f1``, ``ff1`` and ``aos``, each averaged
        over all questions and scaled to 0-100. ``ff1`` and ``aos`` stay 0 when
        either time argument is missing or empty. A textual exact match counts
        as a perfect time-span score. All values are 0.0 when there are no
        questions.
    """
    f1 = exact_match = total = 0
    tot_ff1 = tot_aos = 0
    use_times = bool(references_time and predictions_time)

    for ground_truths, prediction in zip(gold_answers, predictions):
        is_exact_match = metric_max_over_ground_truths(
            exact_match_score, prediction, ground_truths)
        f1 += metric_max_over_ground_truths(
            f1_score, prediction, ground_truths)

        if use_times:
            if is_exact_match:
                frame_f1, aos = 1, 1
            else:
                frame_f1, aos = _best_time_scores(
                    references_time[total], predictions_time[total])
            tot_ff1 += frame_f1
            tot_aos += aos

        exact_match += is_exact_match
        total += 1

    if total == 0:
        # Nothing to average; avoid dividing by zero.
        return {'exact_match': 0.0, 'f1': 0.0, "ff1": 0.0, "aos": 0.0}

    exact_match = 100.0 * exact_match / total
    f1 = 100.0 * f1 / total
    aos = 100.0 * tot_aos / total
    ff1 = 100.0 * tot_ff1 / total

    return {'exact_match': exact_match, 'f1': f1, "ff1":ff1, "aos":aos}

In [None]:
# Environment setup: install Hugging Face transformers from a fresh source checkout, then
# nlp, tpubar, sentencepiece and datasets from PyPI.
# NOTE(review): none of these installs is version-pinned, so a re-run may pick up
# different library versions.
!git clone https://github.com/huggingface/transformers.git
!pip install -q ./transformers
!pip install -q -U nlp
!pip install -q tpubar sentencepiece
!pip install datasets

In [4]:
import torch
import datasets
from transformers import T5ForConditionalGeneration, ByT5Tokenizer,T5Tokenizer
from tqdm.auto import tqdm

In [6]:
# Other checkpoints that can be evaluated instead (swap the name passed below):
#model = T5ForConditionalGeneration.from_pretrained('Splend1dchan/t5-small-squad').to('cuda')
#model = T5ForConditionalGeneration.from_pretrained('valhalla/t5-base-squad').to('cuda')
#model = T5ForConditionalGeneration.from_pretrained('Splend1dchan/t5-large-squad').to('cuda')

#model = T5ForConditionalGeneration.from_pretrained('Splend1dchan/byt5small-squad1024-from6000steps').to('cuda')
#model = T5ForConditionalGeneration.from_pretrained('Splend1dchan/byt5-base-squad').to('cuda')

# Checkpoint under evaluation; it is used with ByT5Tokenizer.
model = T5ForConditionalGeneration.from_pretrained(
    'Splend1dchan/t5lephone-small-textsquad'
)
model = model.to('cuda')

# The tokenizer must match the checkpoint chosen above.
#tokenizer = T5Tokenizer.from_pretrained('t5-small')
tokenizer = ByT5Tokenizer.from_pretrained('google/byt5-small')

Downloading:   0%|          | 0.00/816 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.12G [00:00<?, ?B/s]

In [8]:
# Load the chosen split, build the prompt/target strings, then tokenize them in batches.
# Caching is disabled so changes to the config or helper functions always take effect.
valid_dataset = (
    datasets.load_dataset(f'Splend1dchan/NMSQA_{asr_system}', split=splt)
    .map(add_eos_to_examples, load_from_cache_file=False)
    .map(convert_to_features, batched=True, load_from_cache_file=False)
)

# Expose only the tokenized columns, as torch tensors.
columns = ['input_ids', 'target_ids', 'attention_mask', 'target_attention_mask']
valid_dataset.set_format(type='torch', columns=columns)
dataloader = torch.utils.data.DataLoader(valid_dataset, batch_size=2)


Using custom data configuration Splend1dchan--NMSQA_wav2vec2-large-960h-lv60-self-331c7554e52616ed


Downloading and preparing dataset parquet/Splend1dchan--NMSQA_wav2vec2-large-960h-lv60-self to /home/splend1d/.cache/huggingface/datasets/parquet/Splend1dchan--NMSQA_wav2vec2-large-960h-lv60-self-331c7554e52616ed/0.0.0/0b6d5799bb726b24ad7fc7be720c170d8e497f575d02d47537de9a5bac074901...


  0%|          | 0/2 [00:00<?, ?it/s]

Downloading:   0%|          | 0.00/668k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/8.07M [00:00<?, ?B/s]

  0%|          | 0/2 [00:00<?, ?it/s]

Dataset parquet downloaded and prepared to /home/splend1d/.cache/huggingface/datasets/parquet/Splend1dchan--NMSQA_wav2vec2-large-960h-lv60-self-331c7554e52616ed/0.0.0/0b6d5799bb726b24ad7fc7be720c170d8e497f575d02d47537de9a5bac074901. Subsequent calls will reuse this data.


0ex [00:00, ?ex/s]

  0%|          | 0/11 [00:00<?, ?ba/s]

  f"This sequence already has {self.eos_token}. In future versions this behavior may lead to duplicated"


# Generate Text Answers

In [9]:
# Generate one answer string per example; special tokens are kept in the decoded text
# and stripped later.
answers = []
for step, batch in enumerate(tqdm(dataloader, total=len(dataloader))):
    generated = model.generate(
        input_ids=batch['input_ids'].cuda(),
        attention_mask=batch['attention_mask'].cuda(),
        max_length=_max_output,
    )
    decoded = [tokenizer.decode(token_ids) for token_ids in generated]
    answers.extend(decoded)
    # Print the first batch as a quick sanity check.
    if step == 0:
        print(decoded)
  

  0%|          | 0/5247 [00:00<?, ?it/s]

['<pad>denver broncos </s>', '<pad>denver broncos </s>']


# Find timespan from text answers

In [12]:
# Reload the untokenized split: the ASR context transcript and per-word timings of each
# row are used below to map generated answers back to audio time spans.
valid_word_dataset = datasets.load_dataset(f'Splend1dchan/NMSQA_{asr_system}', split=splt)

Using custom data configuration Splend1dchan--NMSQA_wav2vec2-large-960h-lv60-self-331c7554e52616ed
Reusing dataset parquet (/home/splend1d/.cache/huggingface/datasets/parquet/Splend1dchan--NMSQA_wav2vec2-large-960h-lv60-self-331c7554e52616ed/0.0.0/0b6d5799bb726b24ad7fc7be720c170d8e497f575d02d47537de9a5bac074901)


In [13]:
def edit_distance(string1, string2):
    """Cheap string dissimilarity: length difference plus position-wise mismatches.

    Ref: https://bit.ly/2Pf4a6Z

    The strings are compared character by character from the start, over the
    length of the shorter one; every mismatching position counts 1, and the
    length difference is added on top. This is NOT a Levenshtein distance: a
    single inserted character shifts every later position and is counted as
    many mismatches.

    Args:
        string1: first sequence (argument order does not matter).
        string2: second sequence.

    Returns:
        int: the dissimilarity score; 0 only for identical sequences.
    """
    # Make string1 the shorter (or equal-length) sequence.
    if len(string2) < len(string1):
        string1, string2 = string2, string1

    difference = len(string2) - len(string1)
    # zip stops at the end of the shorter sequence.
    difference += sum(1 for a, b in zip(string1, string2) if a != b)
    return difference

In [14]:
# Span used when a prediction cannot be aligned to the context at all.
_FALLBACK_SPAN = [0, 0, 10, 10]


def _span_with_neighbours(word_times, start_idx, end_idx):
    """Build ``[prev_start, ans_start, ans_end, next_start]`` for a word range.

    ``word_times[i]`` is indexed as ``[start, end]`` for the i-th context word.
    ``prev_start`` is the start time of the word before the answer (or the
    answer start itself at the left edge); ``next_start`` is the start time of
    the word after the answer.
    """
    ans_start = word_times[start_idx][0]
    ans_end = word_times[end_idx][1]
    prev_start = word_times[start_idx - 1][0] if start_idx != 0 else ans_start
    if end_idx != len(word_times) - 1:
        next_start = word_times[end_idx + 1][0]
    else:
        # NOTE(review): at the right edge this is the START of the last answer
        # word, so it lies before ans_end -- confirm this is intended.
        next_start = word_times[end_idx][0]
    return [prev_start, ans_start, ans_end, next_start]


predictions = []
predictions_times = []
references = []
references_times = []
not_extractive = 0

for ref, raw_answer in tqdm(zip(valid_word_dataset, answers), total=len(valid_word_dataset)):
    context_words = ref["context_asr"].lower().split()
    answer_text = raw_answer.replace("<pad>", "").replace("</s>", "")
    answer_words = answer_text.split()
    n_words = len(answer_words)
    word_times = ref["context_times"]

    # Every start index of a full-length window; empty when the prediction is
    # empty or longer than the context.
    window_starts = range(len(context_words) - n_words + 1) if n_words else range(0)

    # 1) Exact match: collect every occurrence of the predicted word sequence.
    candidate_spans = [
        _span_with_neighbours(word_times, i, i + n_words - 1)
        for i in window_starts
        if context_words[i:i + n_words] == answer_words
    ]

    # 2) No exact occurrence: fall back to the closest same-length window.
    if not candidate_spans:
        not_extractive += 1
        answer_joined = " ".join(answer_words)
        best_start = None  # reset per example so indices never leak across examples
        # NOTE(review): the threshold is counted in words but the distance in
        # characters; windows at or above it are rejected.
        best_distance = len(context_words) * 2
        for i in window_starts:
            window_text = " ".join(context_words[i:i + n_words])
            distance = edit_distance(window_text, answer_joined)
            if distance < best_distance:
                best_distance = distance
                best_start = i

        fallback_span = list(_FALLBACK_SPAN)
        if best_start is not None:
            try:
                fallback_span = _span_with_neighbours(
                    word_times, best_start, best_start + n_words - 1)
            except (IndexError, KeyError, TypeError):
                # Timing data does not cover the matched words; keep the default span.
                fallback_span = list(_FALLBACK_SPAN)
        candidate_spans.append(fallback_span)

    predictions_times.append(candidate_spans)
    references_times.append([
        [ref_start, ref_end]
        for ref_start, ref_end in zip(ref["answers"]["audio_full_answer_start"],
                                      ref["answers"]["audio_full_answer_end"])
    ])
    predictions.append(answer_text)
    references.append(ref["answers"]["text"])

print("no ans:",not_extractive)

  0%|          | 0/10493 [00:00<?, ?it/s]

no ans: 656


In [15]:
# Preview up to the first 32 examples: prediction, reference texts, predicted time
# spans and reference time spans. Slicing keeps this safe for runs with fewer examples.
for sample in zip(predictions[:32], references[:32], predictions_times[:32], references_times[:32]):
    print(*sample)

denver broncos  ['Denver Broncos', 'Denver Broncos', 'Denver Broncos'] [[9.896, 10.336, 11.116, 11.196]] [[10.29, 11.15]]
denver broncos  ['Carolina Panthers', 'Carolina Panthers', 'Carolina Panthers'] [[9.896, 10.336, 11.116, 11.196]] [[14.64, 15.559999999999999]]
levi stadium in the san francisco bay area at santa clara california  ['Santa Clara, California', "Levi's Stadium", "Levi's Stadium in the San Francisco Bay Area at Santa Clara, California."] [[21.296, 21.416, 25.156, 25.496]] [[24.02403628117914, 25.32403628117914], [21.43403628117914, 22.27403628117914], [21.43403628117914, 25.32403628117914]]
denver broncos  ['Denver Broncos', 'Denver Broncos', 'Denver Broncos'] [[9.896, 10.336, 11.116, 11.196]] [[10.29, 11.15]]
gold  ['gold', 'gold', 'gold'] [[29.576, 29.976, 30.156, 30.195999999999998]] [[28.448072562358277, 28.77807256235828]]
arabic numerals  ['"golden anniversary"', 'gold-themed', '"golden anniversary'] [[40.756, 40.896, 41.635999999999996, 41.716]] [[28.448072562358

# Evaluate via AOS/FF1 metrics

In [17]:
# Number of leading examples reported as "squad results": everything for the
# validation split, otherwise the rows whose question audio path contains "squad".
len_squad = (
    len(references)
    if splt == "validation"
    else sum("squad" in path for path in valid_word_dataset["question_audio_path"])
)

print("squad results")
res = evaluate(
    references[:len_squad],
    predictions[:len_squad],
    references_time=references_times[:len_squad],
    predictions_time=predictions_times[:len_squad],
)
print(res)

# On the test split the remaining examples are scored separately.
if splt == "test":
    print("OOF results")
    res = evaluate(
        references[len_squad:],
        predictions[len_squad:],
        references_time=references_times[len_squad:],
        predictions_time=predictions_times[len_squad:],
    )
    print(res)

squad results
{'exact_match': 27.961498141618222, 'f1': 45.702214542086516, 'ff1': 64.45928227276359, 'aos': 59.21146849076837}
