In [None]:
    
import json
import torch
from tqdm import tqdm
from transformers import BertTokenizerFast, BertForQuestionAnswering, AdamW, get_linear_schedule_with_warmup
from sklearn.metrics import f1_score
import jiwer
import re
import string

# Device Configuration: run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Load SQuAD Data
def load_data(file_path):
    """Read and return the parsed JSON document stored at *file_path*.

    The file is decoded as UTF-8 explicitly. JSON text is UTF-8 by
    specification, and relying on the platform default encoding (e.g. cp1252
    on Windows) can fail on, or silently mis-decode, non-ASCII characters.
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        return json.load(file)

# Extract Context, Questions, and Answers
def parse_data(data):
    """Flatten a SQuAD-format dict into three parallel lists.

    One row is emitted per gold answer, so a question with k answers
    contributes k rows, and a question with no answers contributes none.

    Returns:
        (contexts, questions, answers). ``answers`` holds the original answer
        dicts (the same objects, not copies).
    """
    ctx_rows, q_rows, ans_rows = [], [], []
    for article in data['data']:
        for para in article['paragraphs']:
            ctx = para['context']
            for qa in para['qas']:
                q_text = qa['question']
                for gold in qa['answers']:
                    ctx_rows.append(ctx)
                    q_rows.append(q_text)
                    ans_rows.append(gold)
    return ctx_rows, q_rows, ans_rows

# Load both splits (train first, then test) and flatten them into parallel lists.
train_data, valid_data = (
    load_data(path) for path in ('spoken_train-v1.1.json', 'spoken_test-v1.1.json')
)
(train_contexts, train_questions, train_answers), (valid_contexts, valid_questions, valid_answers) = (
    parse_data(train_data),
    parse_data(valid_data),
)

# Adjust Answer Indices
def adjust_indices(answers, contexts):
    for answer, context in zip(answers, contexts):
        gold_text = answer['text']
        start_index = answer.get('answer_start', 0)
        end_index = start_index + len(gold_text)
        
        for shift in range(3):
            shifted_start = start_index - shift
            shifted_end = shifted_start + len(gold_text)
            if context[shifted_start:shifted_end] == gold_text:
                answer['answer_start'] = shifted_start
                answer['answer_end'] = shifted_end
                break

# Align gold answer spans with their contexts for both splits (mutates the answer dicts).
for split_answers, split_contexts in ((train_answers, train_contexts), (valid_answers, valid_contexts)):
    adjust_indices(split_answers, split_contexts)

tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

# Tokenize Data: overlap (in tokens) passed to the tokenizer as `stride`.
doc_stride = 128

def tokenize_and_encode(contexts, questions):
    """Batch-encode (context, question) pairs with the module-level ``tokenizer``.

    The context is the first sequence of each pair and the question the
    second. Encodings are truncated to at most 512 tokens and padded to the
    longest encoding in the call.

    NOTE(review): the tokenizer only applies ``stride`` to the overflow windows
    it returns when ``return_overflowing_tokens=True``. That flag is not set
    here, so ``doc_stride`` does not change the returned encodings.
    """
    options = {'truncation': True, 'padding': True, 'max_length': 512, 'stride': doc_stride}
    return tokenizer(contexts, questions, **options)

# Encode the training split first, then the validation split.
train_encodings, valid_encodings = (
    tokenize_and_encode(train_contexts, train_questions),
    tokenize_and_encode(valid_contexts, valid_questions),
)

def add_positions(encodings, answers):
    """Attach token-level answer labels to *encodings*, mutating it in place.

    For example ``i``, the answer's first and last characters are mapped
    to token indices with ``encodings.char_to_token(i, char)``. The lists are
    stored under ``'start_positions'`` and ``'end_positions'`` via
    ``encodings.update``.

    When the answer cannot be fully mapped to tokens (e.g. it was truncated
    out of the context), both labels point at index 0, the leading [CLS]
    token.

    The previous fallback, ``tokenizer.model_max_length - 1``, had two
    problems. First, it is a genuine token position whenever a batch is
    padded to the full 512 tokens. Second, when only one end was missing, it
    produced an inverted or mixed span.

    Args:
        encodings: tokenizer output supporting ``char_to_token`` and ``update``.
        answers: answer dicts with ``'text'`` and ``'answer_start'``.
            ``'answer_end'`` (exclusive) is used when present; otherwise it is
            derived from the text length.
    """
    start_positions, end_positions = [], []
    for i, answer in enumerate(answers):
        start_char = max(0, answer['answer_start'])
        end_exclusive = answer.get('answer_end', answer['answer_start'] + len(answer['text']))
        end_char = max(0, end_exclusive - 1)  # last character of the answer (inclusive)

        start_token = encodings.char_to_token(i, start_char)
        end_token = encodings.char_to_token(i, end_char)
        # The final character may be whitespace that maps to no token; step back one character.
        if end_token is None and end_char > start_char:
            end_token = encodings.char_to_token(i, end_char - 1)

        if start_token is None or end_token is None:
            # Answer not (fully) present in the encoded context: label as [CLS].
            start_token = end_token = 0

        start_positions.append(start_token)
        end_positions.append(end_token)

    encodings.update({'start_positions': start_positions, 'end_positions': end_positions})

# Attach token-level start/end labels to both splits (mutates each encoding in place).
for split_encodings, split_answers in ((train_encodings, train_answers), (valid_encodings, valid_answers)):
    add_positions(split_encodings, split_answers)

# Dataset Class
class QADataset(torch.utils.data.Dataset):
    """Map-style dataset that serves one tokenized QA example per index.

    ``encodings`` is any mapping from feature name (e.g. ``input_ids``,
    ``attention_mask``, ``start_positions``) to a per-example sequence. A
    tokenizer ``BatchEncoding`` and a plain ``dict`` both work. An
    ``'input_ids'`` entry is required, because it defines the length.
    """

    def __init__(self, encodings):
        self.encodings = encodings

    def __getitem__(self, idx):
        # One tensor per feature for example `idx`.
        return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}

    def __len__(self):
        # Subscript rather than attribute access so plain dicts are supported too.
        return len(self.encodings['input_ids'])

train_dataset = QADataset(train_encodings)
valid_dataset = QADataset(valid_encodings)

# Data Loaders
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=16, shuffle=True)
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=16)

# Number of passes over the training data. The LR schedule is sized from this
# same value. Previously it was sized for 3 epochs while the loop ran 5, so the
# learning rate hit zero after epoch 3 and the last two epochs did not update
# the model.
num_epochs = 5

# Model and Optimizer Setup
model = BertForQuestionAnswering.from_pretrained("bert-base-uncased").to(device)
# torch.optim.AdamW replaces the deprecated transformers.AdamW. eps=1e-6
# matches the transformers implementation's default.
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.01, eps=1e-6)
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=0, num_training_steps=len(train_loader) * num_epochs
)

# Training Loop with F1 and Accuracy Calculation (metrics are over token positions).
for epoch in range(num_epochs):
    model.train()
    total_loss, correct_start, correct_end, total_samples = 0, 0, 0, 0
    all_start_preds, all_start_trues, all_end_preds, all_end_trues = [], [], [], []

    loop = tqdm(train_loader, leave=True, desc=f'Epoch {epoch+1}')
    for batch in loop:
        optimizer.zero_grad()

        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        # Segment ids distinguish context from question. Inference passes them
        # through `model(**inputs)`, so training must see them too.
        token_type_ids = batch['token_type_ids'].to(device)
        start_positions = batch['start_positions'].to(device)
        end_positions = batch['end_positions'].to(device)

        outputs = model(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            start_positions=start_positions,
            end_positions=end_positions,
        )
        loss = outputs.loss
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

        total_loss += loss.item()
        start_pred = torch.argmax(outputs.start_logits, dim=1)
        end_pred = torch.argmax(outputs.end_logits, dim=1)

        correct_start += (start_pred == start_positions).sum().item()
        correct_end += (end_pred == end_positions).sum().item()
        total_samples += start_positions.size(0)

        all_start_preds.extend(start_pred.cpu().numpy())
        all_start_trues.extend(start_positions.cpu().numpy())
        all_end_preds.extend(end_pred.cpu().numpy())
        all_end_trues.extend(end_positions.cpu().numpy())

        loop.set_postfix(loss=loss.item())

    accuracy = (correct_start + correct_end) / (2 * total_samples)
    avg_loss = total_loss / len(train_loader)
    start_f1 = f1_score(all_start_trues, all_start_preds, average='macro')
    end_f1 = f1_score(all_end_trues, all_end_preds, average='macro')
    overall_f1 = (start_f1 + end_f1) / 2

    print(f'Epoch {epoch+1} - Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}, Start F1: {start_f1:.4f}, End F1: {end_f1:.4f}, Overall F1: {overall_f1:.4f}')


In [None]:
# Prediction function
def get_prediction(context, question):
    """Predict the answer text for *question* over *context* using the module-level ``model``.

    The pair is encoded context-first, with the question as the second
    sequence. This is the same order ``tokenize_and_encode`` uses for
    training. Previously the question came first, so the model was queried
    with a layout it was never trained on.

    Returns the decoded span from the argmax start token to the argmax end
    token, with special tokens ([CLS]/[SEP]) removed. The result is an empty
    string when the predicted end precedes the start.
    """
    inputs = tokenizer(context, question, return_tensors='pt', max_length=512, truncation=True).to(device)
    # Training leaves the model in train mode; disable dropout for deterministic predictions.
    model.eval()
    with torch.no_grad():
        outputs = model(**inputs)
    answer_start = int(torch.argmax(outputs.start_logits))
    answer_end = int(torch.argmax(outputs.end_logits)) + 1
    answer_ids = inputs['input_ids'][0][answer_start:answer_end]
    return tokenizer.decode(answer_ids, skip_special_tokens=True)

def prediction(data, model):
    """Predict an answer for every question in a SQuAD-format dict.

    Args:
        data: dict with ``data -> paragraphs -> (context, qas -> (id, question))``.
        model: switched to eval mode before predicting.

    Returns:
        dict mapping each question id to its predicted answer text.

    NOTE(review): ``get_prediction`` runs the module-level ``model``, not this
    argument. The two are the same object at every call site in this file.
    """
    # The training loop calls model.train(); inference must not apply dropout.
    model.eval()
    pred_answers = {}
    for entry in data['data']:
        for paragraph in entry['paragraphs']:
            context = paragraph['context']
            for qa in paragraph['qas']:
                pred_answers[qa['id']] = get_prediction(context, qa['question'])
    return pred_answers

# Compute WER function
def compute_wer(ground_truth, predicted):
    return jiwer.wer(ground_truth, predicted)

# Load validation data (presumably the test set transcribed at ~44% ASR WER — confirm).
with open("spoken_test-v1.1_WER44.json", 'r', encoding='utf-8') as file:
    validation_data = json.load(file)

# Generate predictions
predicted_answers = prediction(validation_data, model)

# Index each question's first gold answer once. This replaces rescanning the
# whole dataset for every prediction (O(n^2)). Questions without answers are
# skipped; the first occurrence of an id wins.
ground_truths = {}
for entry in validation_data['data']:
    for paragraph in entry['paragraphs']:
        for qa_pair in paragraph['qas']:
            if qa_pair['answers']:
                ground_truths.setdefault(qa_pair['id'], qa_pair['answers'][0]['text'])

# Initialize variables for WER calculation
total_wer = 0
question_count = 0

# Compute WER for each prediction that has a gold answer
for question_id, predicted_answer in predicted_answers.items():
    ground_truth = ground_truths.get(question_id)
    if ground_truth is None:
        continue
    wer = compute_wer(ground_truth, predicted_answer)
    total_wer += wer
    question_count += 1

# Calculate and print cumulative WER
if question_count > 0:
    cumulative_wer = total_wer / question_count
    print(f"Cumulative WER: {cumulative_wer:.4f}")
else:
    print("No questions found to compute WER.")

In [None]:
# Evaluate on the WER54 split (presumably ~54% ASR word error rate — confirm).
with open("spoken_test-v1.1_WER54.json", 'r', encoding='utf-8') as j:
    valContents = json.load(j)
# Predict after the file is closed; it no longer stays open for the whole inference run.
predAnswers = prediction(valContents, model)

# Index each question's first gold answer once instead of rescanning the dataset per id.
goldById = {}
for data in valContents['data']:
    for txt in data['paragraphs']:
        for qa in txt['qas']:
            if qa['answers']:
                goldById.setdefault(qa['id'], qa['answers'][0]['text'])

totalWER = 0
numQuestions = 0
for qid, predAnswer in predAnswers.items():
    if qid not in goldById:
        continue
    wer = compute_wer(goldById[qid], predAnswer)
    totalWER += wer
    numQuestions += 1

# Guard the division: previously an empty match set raised ZeroDivisionError.
if numQuestions > 0:
    cWER = totalWER / numQuestions
    print(f"CumulativeWER: {cWER}")
else:
    print("No questions found to compute WER.")
