In [68]:
import requests
import json
import torch
import os
from tqdm import tqdm

# Peek at the raw training file; each entry under 'data' holds a 'title' and
# its 'paragraphs'.
with open('spoken_train-v1.1.json', 'rb') as f:
    squad = json.loads(f.read())
squad['data'][0].keys()

dict_keys(['title', 'paragraphs'])

In [69]:

# Use the GPU when PyTorch can see one, otherwise run on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'{device}')

cuda


In [70]:
def read_data(path):
  """Flatten a SQuAD-format JSON file into three parallel lists.

  Every answer of every question becomes its own row, so a question with
  several reference answers is repeated once per answer.

  Args:
    path: path to the JSON file (top-level 'data' -> 'paragraphs' -> 'qas').

  Returns:
    (contexts, questions, answers): equal-length lists; each answer is the raw
    answer dict from the file.
  """
  contexts, questions, answers = [], [], []
  with open(path, 'rb') as fh:
    squad = json.load(fh)
  for group in squad['data']:
    for passage in group['paragraphs']:
      ctx = passage['context']
      for qa in passage['qas']:
        q = qa['question']
        n_answers = len(qa['answers'])
        # Repeat the context/question once per reference answer.
        contexts.extend([ctx] * n_answers)
        questions.extend([q] * n_answers)
        answers.extend(qa['answers'])
  return contexts, questions, answers
# Flatten both splits into parallel (context, question, answer) lists.
train_contexts, train_questions, train_answers = read_data(path='spoken_train-v1.1.json')
valid_contexts, valid_questions, valid_answers = read_data(path='spoken_test-v1.1.json')
     


In [71]:
def add_end_idx(answers, contexts, max_shift=2):
  """Add an exclusive character offset 'answer_end' to every answer, in place.

  The recorded 'answer_start' may sit up to `max_shift` characters to the
  right of where 'text' actually occurs in the context. Shifts 0..max_shift
  are tried in order. The first shift at which context[start:end] equals the
  answer text is used, and both offsets are updated.

  If no shift matches, 'answer_end' is still set from the unshifted start.
  That way every answer carries the key that add_token_positions reads
  (previously the key was silently missing and caused a KeyError later).

  Args:
    answers: list of dicts with 'text' and 'answer_start'; mutated in place.
    contexts: context strings aligned with `answers`.
    max_shift: largest left shift to try (default 2, as before).
  """
  for answer, context in zip(answers, contexts):
    gold_text = answer['text']
    start_idx = answer['answer_start']
    end_idx = start_idx + len(gold_text)
    # Fallback when no realignment matches: keep the recorded offsets.
    answer['answer_end'] = end_idx
    for shift in range(max_shift + 1):
      # A negative start would wrap around to the end of the string.
      if start_idx - shift < 0:
        break
      if context[start_idx - shift:end_idx - shift] == gold_text:
        answer['answer_start'] = start_idx - shift
        answer['answer_end'] = end_idx - shift
        break
# Realign answer offsets and record 'answer_end' for both splits (in place).
add_end_idx(answers=train_answers, contexts=train_contexts)
add_end_idx(answers=valid_answers, contexts=valid_contexts)

In [72]:
from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

doc_stride = 128

# Shared settings for both splits. Pairs are encoded as (context, question),
# truncated to at most 512 tokens, and padded to the longest encoded pair.
# NOTE(review): `stride` only shapes the overflow windows returned when
# return_overflowing_tokens=True. That flag is not set here, so over-long pairs
# are simply truncated rather than split into overlapping windows.
_encode_kwargs = dict(
    truncation=True,
    padding=True,
    max_length=512,
    stride=doc_stride,
)

train_encodings = tokenizer(train_contexts, train_questions, **_encode_kwargs)
valid_encodings = tokenizer(valid_contexts, valid_questions, **_encode_kwargs)

In [73]:
# Show which fields the tokenizer produced for each example.
train_encodings.keys()

dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])

In [74]:
def add_token_positions(encodings, answers, ignore_index=None):
    """Attach token-level answer positions to `encodings`, in place.

    For example i, the answer's character span [answer_start, answer_end) is
    mapped to token indices with `encodings.char_to_token(i, char)`. The end
    position is the token holding the answer's LAST character (inclusive).
    The results are stored under 'start_positions' / 'end_positions' via
    `encodings.update`.

    Args:
        encodings: tokenizer output supporting `char_to_token(i, char)` and
            `update` (e.g. a BatchEncoding).
        answers: dicts with 'answer_start' and 'text'. 'answer_end'
            (exclusive) is used when present; otherwise it is derived as
            answer_start + len(text) instead of raising KeyError.
        ignore_index: position written for both boundaries when the span
            cannot be mapped (e.g. part of it was truncated away). Defaults to
            `tokenizer.model_max_length`, as before. Presumably this lies past
            the padded sequence length, so the QA loss ignores it; confirm.
    """
    if ignore_index is None:
        ignore_index = tokenizer.model_max_length
    start_positions = []
    end_positions = []
    for i, answer in enumerate(answers):
        char_start = answer['answer_start']
        char_end = answer.get('answer_end', char_start + len(answer['text']))
        start_tok = encodings.char_to_token(i, max(0, char_start))
        end_tok = encodings.char_to_token(i, max(0, char_end - 1))
        # A span with either boundary outside the encoded window is not fully
        # answerable; label both ends as ignored instead of training on a
        # half-valid span.
        if start_tok is None or end_tok is None:
            start_tok = end_tok = ignore_index
        start_positions.append(start_tok)
        end_positions.append(end_tok)

    encodings.update({'start_positions': start_positions, 'end_positions': end_positions})
# Store token-level start/end labels on both encodings (in place).
add_token_positions(encodings=train_encodings, answers=train_answers)
add_token_positions(encodings=valid_encodings, answers=valid_answers)


In [75]:
# Sanity-check the first few token-level answer start indices.
train_encodings['start_positions'][:10]

[36, 11, 55, 107, 25, 21, 43, 67, 28, 22]

In [76]:
class SQuAD_Dataset(torch.utils.data.Dataset):
  """Map-style dataset over pre-tokenized encodings.

  Item `idx` is a dict with one tensor per key of `encodings`, built from that
  key's idx-th entry. The length is the number of `input_ids` rows.
  """

  def __init__(self, encodings):
    self.encodings = encodings

  def __len__(self):
    return len(self.encodings.input_ids)

  def __getitem__(self, idx):
    item = {}
    for key, values in self.encodings.items():
      item[key] = torch.tensor(values[idx])
    return item
     

In [77]:

# Wrap the tokenized splits so DataLoader can batch them.
train_dataset = SQuAD_Dataset(encodings=train_encodings)
valid_dataset = SQuAD_Dataset(encodings=valid_encodings)
     

In [78]:
from torch.utils.data import DataLoader

BATCH_SIZE = 16
# Seed the shuffling generator so the training batch order is the same on
# every Restart & Run All.
LOADER_SEED = 42

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(LOADER_SEED),
)
valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE)

In [79]:
from transformers import BertForQuestionAnswering

# Same checkpoint as the tokenizer. Per the load warning below, the QA head
# (qa_outputs.weight / qa_outputs.bias) starts freshly initialized.
model = BertForQuestionAnswering.from_pretrained(
    pretrained_model_name_or_path="bert-base-uncased"
)

Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [80]:
import torch
from tqdm import tqdm
from sklearn.metrics import f1_score

N_EPOCHS = 5
# transformers.AdamW is deprecated. torch.optim.AdamW with eps=1e-6 and
# weight_decay=0.0 matches the defaults the transformers implementation used.
optim = torch.optim.AdamW(model.parameters(), lr=5e-5, eps=1e-6, weight_decay=0.0)

model.to(device)

for epoch in range(N_EPOCHS):
    model.train()
    total_loss = 0
    correct_start = 0
    correct_end = 0
    total_samples = 0

    # Predictions and labels collected for the epoch-level F1 calculation
    all_start_preds = []
    all_start_trues = []
    all_end_preds = []
    all_end_trues = []

    loop = tqdm(train_loader, leave=True, desc=f'Epoch {epoch+1}')
    for batch in loop:
        optim.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        start_positions = batch['start_positions'].to(device)
        end_positions = batch['end_positions'].to(device)

        # NOTE(review): token_type_ids are not passed, so training sees segment
        # id 0 for every token; inference should feed the model the same way.
        outputs = model(
            input_ids,
            attention_mask=attention_mask,
            start_positions=start_positions,
            end_positions=end_positions,
        )
        loss = outputs.loss
        loss.backward()
        optim.step()

        batch_loss = loss.item()
        total_loss += batch_loss

        # Predicted start/end token indices
        start_pred = torch.argmax(outputs.start_logits, dim=1)
        end_pred = torch.argmax(outputs.end_logits, dim=1)

        correct_start += (start_pred == start_positions).sum().item()
        correct_end += (end_pred == end_positions).sum().item()
        total_samples += len(start_positions)

        all_start_preds.extend(start_pred.cpu().numpy())
        all_start_trues.extend(start_positions.cpu().numpy())
        all_end_preds.extend(end_pred.cpu().numpy())
        all_end_trues.extend(end_positions.cpu().numpy())

        loop.set_postfix(loss=batch_loss)

    # Exact-position accuracy over start and end predictions for this epoch
    accuracy = (correct_start + correct_end) / (2 * total_samples)
    avg_loss = total_loss / len(train_loader)

    # Macro F1 treating each token position as a class
    start_f1 = f1_score(all_start_trues, all_start_preds, average='macro')
    end_f1 = f1_score(all_end_trues, all_end_preds, average='macro')
    overall_f1 = (start_f1 + end_f1) / 2

    print(f'Epoch {epoch+1} - Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}, Start F1: {start_f1:.4f}, End F1: {end_f1:.4f}, Overall F1: {overall_f1:.4f}')


Epoch 1:   1%|          | 24/2320 [00:08<14:09,  2.70it/s, loss=3.92] 


KeyboardInterrupt: 

In [81]:

model.eval()
accuracies = []  # per-batch mean of start/end accuracy (kept for inspection)
n_correct = 0    # correct start + end predictions over the whole split
n_total = 0      # number of start + end predictions over the whole split

# Disable gradient tracking for the whole evaluation pass
with torch.no_grad():
    for batch in tqdm(valid_loader):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        start_true = batch['start_positions'].to(device)
        end_true = batch['end_positions'].to(device)

        outputs = model(input_ids, attention_mask=attention_mask)

        start_pred = torch.argmax(outputs['start_logits'], dim=1)
        end_pred = torch.argmax(outputs['end_logits'], dim=1)

        start_hits = start_pred == start_true
        end_hits = end_pred == end_true
        accuracies.append((start_hits.float().mean().item() + end_hits.float().mean().item()) / 2)

        n_correct += start_hits.sum().item() + end_hits.sum().item()
        n_total += 2 * len(start_true)

# Sample-weighted accuracy. Averaging `accuracies` would over-weight a smaller
# final batch.
val_accuracy = n_correct / n_total if n_total else float('nan')
print(f'Validation accuracy: {val_accuracy:.4f}')

  2%|▏         | 24/993 [00:02<01:42,  9.49it/s]


KeyboardInterrupt: 

In [82]:
import jiwer

model.eval()
acc = []
wer = []

with torch.no_grad():
    for batch in tqdm(valid_loader):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        start_true = batch['start_positions'].to(device)
        end_true = batch['end_positions'].to(device)

        outputs = model(input_ids, attention_mask=attention_mask)

        start_pred = torch.argmax(outputs['start_logits'], dim=1)
        end_pred = torch.argmax(outputs['end_logits'], dim=1)

        # Per-batch accuracy of start and end predictions
        acc.append(((start_pred == start_true).sum() / len(start_pred)).item())
        acc.append(((end_pred == end_true).sum() / len(end_pred)).item())

        for i in range(len(start_true)):
            # End positions are INCLUSIVE token indices (the token holding the
            # answer's last character), so slice up to end + 1.
            s_true, e_true = int(start_true[i]), int(end_true[i]) + 1
            s_pred, e_pred = int(start_pred[i]), int(end_pred[i]) + 1
            # Drop [CLS]/[SEP]/[PAD] so they are not counted as words.
            true_text = tokenizer.decode(input_ids[i][s_true:e_true], skip_special_tokens=True)
            pred_text = tokenizer.decode(input_ids[i][s_pred:e_pred], skip_special_tokens=True)
            if true_text.strip() == "":
                # No reference span in this window (e.g. truncated answer)
                continue
            wer.append(jiwer.wer(true_text, pred_text))

acc = sum(acc) / len(acc) if acc else float('nan')
wer = sum(wer) / len(wer) if wer else float('nan')
print(f'Accuracy: {acc:.4f}')
print(f'WER: {wer:.4f}')


  2%|▏         | 19/993 [00:02<01:45,  9.24it/s]


KeyboardInterrupt: 

In [83]:
def get_prediction(context, question, qa_model=None):
  """Return the model's answer string for one (context, question) pair.

  NOTE: a second `get_prediction(text, question)` is defined further down in
  this cell and replaces this one once the cell has run.

  Args:
    context: passage the answer is extracted from.
    question: question to answer.
    qa_model: model to run; defaults to the global `model`.

  Returns:
    Decoded answer text. It is empty when the predicted end precedes the
    predicted start.
  """
  qa_model = model if qa_model is None else qa_model
  # Same pair order as train_encodings (context first, question second), so
  # token positions mean what the model learned. Truncate to 512 tokens.
  inputs = tokenizer.encode_plus(context, question, return_tensors='pt',
                                 max_length=512, truncation=True).to(device)
  with torch.no_grad():
    # Only input_ids/attention_mask, matching how the training loop called the model.
    outputs = qa_model(input_ids=inputs['input_ids'],
                       attention_mask=inputs['attention_mask'])
  answer_start = torch.argmax(outputs['start_logits'])
  # End labels point at the answer's last token; +1 makes the slice include it.
  answer_end = torch.argmax(outputs['end_logits']) + 1
  answer = tokenizer.convert_tokens_to_string(
      tokenizer.convert_ids_to_tokens(inputs['input_ids'][0][answer_start:answer_end]))
  return answer

def normalize_text(s):
  """Normalize an answer for comparison: lowercase, strip punctuation,
  replace the articles a/an/the with spaces, then collapse whitespace."""
  import string, re
  def remove_articles(text):
    regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
    return re.sub(regex, " ", text)
  def white_space_fix(text):
    return " ".join(text.split())
  def remove_punc(text):
    exclude = set(string.punctuation)
    return "".join(ch for ch in text if ch not in exclude)
  def lower(text):
    return text.lower()

  return white_space_fix(remove_articles(remove_punc(lower(s))))

def exact_match(prediction, truth):
  """True when prediction and truth are identical after normalize_text."""
  return normalize_text(prediction) == normalize_text(truth)

def compute_f1(prediction, truth):
  """Token-level F1 between prediction and truth after normalize_text.

  Shared tokens are counted as a multiset, as in the official SQuAD v1.1
  evaluation. A token repeated k times contributes min(k_pred, k_truth)
  matches (a plain set would count it once). The score is not rounded, so
  averages over many questions stay exact.

  Returns:
    1 or 0 when either side is empty (1 only if both are), else a float in (0, 1].
  """
  from collections import Counter

  pred_tokens = normalize_text(prediction).split()
  truth_tokens = normalize_text(truth).split()

  # if either the prediction or the truth is no-answer then f1 = 1 if they agree, 0 otherwise
  if len(pred_tokens) == 0 or len(truth_tokens) == 0:
    return int(pred_tokens == truth_tokens)

  common = Counter(pred_tokens) & Counter(truth_tokens)
  num_same = sum(common.values())

  # if there are no common tokens then f1 = 0
  if num_same == 0:
    return 0

  prec = num_same / len(pred_tokens)
  rec = num_same / len(truth_tokens)
  return 2 * (prec * rec) / (prec + rec)
  
def question_answer(context, question, answer):
  """Score the model's prediction for one question against a reference answer.

  Returns:
    (em_score, f1): exact-match flag from exact_match and token-level F1 from
    compute_f1. Previously both were computed and then discarded.
  """
  prediction = get_prediction(context, question)
  em_score = exact_match(prediction, answer)
  # Named `f1` so it does not shadow sklearn's f1_score inside this function.
  f1 = compute_f1(prediction, answer)
  return em_score, f1

def get_prediction(text, question, qa_model=None):
  """Return the model's answer string for one (context, question) pair.

  Args:
    text: context passage the answer is extracted from.
    question: question to answer.
    qa_model: model to run; defaults to the global `model`.

  Returns:
    Decoded answer text. It is empty when the predicted end precedes the
    predicted start.
  """
  qa_model = model if qa_model is None else qa_model
  # Same pair order as train_encodings (context first, question second).
  # Encoding (question, context) would shift every token position relative to
  # what the model was trained on.
  inputs = tokenizer.encode_plus(text, question, return_tensors='pt',
                                 max_length=512, truncation=True).to(device)

  with torch.no_grad():
    # Only input_ids/attention_mask, matching the training loop, which never
    # passed token_type_ids to the model.
    outputs = qa_model(input_ids=inputs['input_ids'],
                       attention_mask=inputs['attention_mask'])
  answer_start = torch.argmax(outputs['start_logits'])
  # End labels point at the answer's last token; +1 makes the slice include it.
  answer_end = torch.argmax(outputs['end_logits']) + 1

  answer = tokenizer.convert_tokens_to_string(
      tokenizer.convert_ids_to_tokens(inputs['input_ids'][0][answer_start:answer_end]))
  return answer

def prediction(contents, model):
  """Predict an answer for every question in a SQuAD-format dict.

  Args:
    contents: parsed JSON with 'data' -> 'paragraphs' -> 'context'/'qas'.
    model: QA model used for every prediction (previously ignored in favor of
      the global `model`). It is switched to eval mode so dropout is off.

  Returns:
    dict mapping question id -> predicted answer string.
  """
  model.eval()
  predAnswers = {}
  for data in contents['data']:
    for txt in data['paragraphs']:
      text = txt['context']
      for qa in txt['qas']:
        predAnswers[qa['id']] = get_prediction(text, qa['question'], model)
  return predAnswers

# Predict every question in the WER44 test file and save the
# {question id: answer} mapping as JSON in results.txt.
with open("spoken_test-v1.1_WER44.json", 'r') as j:
    valContents = json.load(j)
predAnswers = prediction(valContents, model)

with open('results.txt', 'w') as convert_file:
    json.dump(predAnswers, convert_file)

In [84]:
import re
import string
def normalize_answer(s):
    """Canonicalize an answer string for comparison (SQuAD-style).

    Steps, in order: lowercase; drop every character in string.punctuation;
    replace the standalone words "a", "an", "the" with a space; collapse runs
    of whitespace to single spaces and trim the ends.
    """
    text = s.lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())
def compute_wer(ground_truth, predicted):
    """Word error rate of `predicted` against the reference `ground_truth`.

    Both strings are split on whitespace. The result is the word-level
    Levenshtein distance (substitutions + insertions + deletions) divided by
    the number of reference words. It can exceed 1.0 when the prediction has
    many extra words.

    WER is undefined for an empty reference. Instead of dividing by zero, this
    returns 0.0 when the prediction is also empty and 1.0 otherwise.
    """
    ref = ground_truth.split()
    hyp = predicted.split()
    if not ref:
        return 0.0 if not hyp else 1.0

    # prev[j] is the edit distance between the reference prefix processed so
    # far and hyp[:j]. Only one DP row is kept at a time.
    prev = list(range(len(hyp) + 1))
    for i, ref_word in enumerate(ref, start=1):
        cur = [i] + [0] * len(hyp)
        for j, hyp_word in enumerate(hyp, start=1):
            if ref_word == hyp_word:
                cur[j] = prev[j - 1]
            else:
                # substitution, insertion, deletion
                cur[j] = min(prev[j - 1], cur[j - 1], prev[j]) + 1
        prev = cur
    return prev[-1] / len(ref)



In [None]:
with open("spoken_test-v1.1_WER44.json", 'r') as j:
    valContents = json.load(j)
predAnswers = prediction(valContents, model)

# Index the first reference answer of every question once, instead of
# rescanning the whole dataset for each prediction (quadratic).
groundTruths = {}
for data in valContents['data']:
    for txt in data['paragraphs']:
        for qa in txt['qas']:
            groundTruths.setdefault(qa['id'], qa['answers'][0]['text'])

totalWER = 0
numQuestions = 0
for qid, predAnswer in predAnswers.items():
    if qid not in groundTruths:
        continue
    # Compare normalized text. Predictions are decoded from bert-base-uncased
    # word pieces (lowercased, punctuation split off), while references keep
    # their original case and punctuation.
    reference = normalize_answer(groundTruths[qid])
    hypothesis = normalize_answer(predAnswer)
    if not reference:
        # WER is undefined for an empty reference (e.g. an answer of just "The").
        continue
    wer = compute_wer(reference, hypothesis)
    totalWER += wer
    numQuestions += 1

cWER = totalWER / numQuestions if numQuestions else float('nan')
print(f"WER: {cWER}")


In [None]:
with open("spoken_test-v1.1_WER54.json", 'r') as j:
    valContents = json.load(j)
predAnswers = prediction(valContents, model)

# Index the first reference answer of every question once, instead of
# rescanning the whole dataset for each prediction (quadratic).
groundTruths = {}
for data in valContents['data']:
    for txt in data['paragraphs']:
        for qa in txt['qas']:
            groundTruths.setdefault(qa['id'], qa['answers'][0]['text'])

totalWER = 0
numQuestions = 0
for qid, predAnswer in predAnswers.items():
    if qid not in groundTruths:
        continue
    # Compare normalized text. Predictions are decoded from bert-base-uncased
    # word pieces (lowercased, punctuation split off), while references keep
    # their original case and punctuation.
    reference = normalize_answer(groundTruths[qid])
    hypothesis = normalize_answer(predAnswer)
    if not reference:
        # WER is undefined for an empty reference (e.g. an answer of just "The").
        continue
    wer = compute_wer(reference, hypothesis)
    totalWER += wer
    numQuestions += 1

cWER = totalWER / numQuestions if numQuestions else float('nan')
print(f"WER: {cWER}")
