In [2]:
import torch
import json
from transformers import DistilBertTokenizerFast, DistilBertForQuestionAnswering, AdamW
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import jiwer
import logging

logging.basicConfig(level=logging.INFO)
log = logging.getLogger(__name__)

# Set device for training
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Function to read data file and parse questions, contexts, and answers
def read_data(file_path):
    try:
        with open(file_path, 'r', encoding='utf-8') as file:
            data = json.load(file)
    except FileNotFoundError:
        log.error(f"Unable to locate the file: {file_path}")
        raise

    contexts = []
    questions = []
    answers = []
    total_questions = 0
    possible_count = 0
    impossible_count = 0

    for section in data["data"]:
        for passage in section["paragraphs"]:
            context = passage["context"]
            for qna in passage["qas"]:
                question = qna["question"]
                total_questions += 1
                if "is_impossible" in qna and qna["is_impossible"]:
                    impossible_count += 1
                else:
                    possible_count += 1
                for ans in qna.get("answers", []):
                    contexts.append(context.lower())
                    questions.append(question.lower())
                    answers.append(ans)

    return total_questions, possible_count, impossible_count, contexts, questions, answers

# Load training and validation data
try:
    train_totals, train_possible, train_impossible, train_texts, train_qs, train_ans = read_data("spoken_train-v1.1.json")
    valid_totals, valid_possible, valid_impossible, valid_texts, valid_qs, valid_ans = read_data("spoken_test-v1.1.json")
except Exception as err:
    log.error(f"Data loading error: {err}")
    exit()

# Function to calculate answer end positions
def compute_answer_ends(answers, contexts):
    for ans, context in zip(answers, contexts):
        answer_txt = ans.get("text", "").lower()
        answer_start = ans.get("answer_start", -1)
        ans["answer_end"] = answer_start + len(answer_txt)

compute_answer_ends(train_ans, train_texts)
compute_answer_ends(valid_ans, valid_texts)

MAX_SEQ_LEN = 512
MODEL_NAME = "distilbert-base-uncased"

tokenizer = DistilBertTokenizerFast.from_pretrained(MODEL_NAME)

# Tokenize questions and contexts
train_enc = tokenizer(train_qs, train_texts, max_length=MAX_SEQ_LEN, padding=True, truncation=True)
valid_enc = tokenizer(valid_qs, valid_texts, max_length=MAX_SEQ_LEN, padding=True, truncation=True)

# Custom dataset class
class QuestionAnswerDataset(Dataset):
    def __init__(self, encodings, answer_data):
        self.encodings = encodings
        self.answer_data = answer_data

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item["start_pos"] = torch.tensor(self.answer_data[idx].get("answer_start", -1))
        item["end_pos"] = torch.tensor(self.answer_data[idx].get("answer_end", -1))
        return item

    def __len__(self):
        return len(self.encodings.input_ids)

# Initialize datasets and loaders
train_ds = QuestionAnswerDataset(train_enc, train_ans)
valid_ds = QuestionAnswerDataset(valid_enc, valid_ans)

train_dl = DataLoader(train_ds, batch_size=16, shuffle=True)
valid_dl = DataLoader(valid_ds, batch_size=1)

qa_model = DistilBertForQuestionAnswering.from_pretrained(MODEL_NAME).to(device)
optim = AdamW(qa_model.parameters(), lr=5e-5)

# Training function for one epoch
def train_epoch(model, loader, optimizer):
    model.train()
    total_epoch_loss = 0.0
    for batch in tqdm(loader, desc="Training"):
        optimizer.zero_grad()
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        start_positions = batch["start_pos"].to(device)
        end_positions = batch["end_pos"].to(device)
        
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, 
                        start_positions=start_positions, end_positions=end_positions)
        loss = outputs.loss
        total_epoch_loss += loss.item()
        loss.backward()
        optimizer.step()

    return total_epoch_loss / len(loader)

# Evaluation function with WER and F1 scoring
def evaluate(model, loader):
    model.eval()
    wer_scores = []
    f1_scores = []
    precision_scores = []
    recall_scores = []

    for batch in tqdm(loader, desc="Evaluating"):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        true_start = batch["start_pos"].to(device)
        true_end = batch["end_pos"].to(device)
        with torch.no_grad():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        start_preds = torch.argmax(outputs.start_logits, dim=1)
        end_preds = torch.argmax(outputs.end_logits, dim=1)
        
        for i in range(len(true_start)):
            pred_answer = tokenizer.decode(input_ids[i][start_preds[i]:end_preds[i]+1])
            true_answer = tokenizer.decode(input_ids[i][true_start[i]:true_end[i]+1])

            if true_answer.strip():
                wer_score = jiwer.wer(true_answer, pred_answer)
                wer_scores.append(wer_score)

                # Calculate F1, precision, and recall for the predicted and actual answer spans
                true_tokens = set(tokenizer.encode(true_answer, add_special_tokens=False))
                pred_tokens = set(tokenizer.encode(pred_answer, add_special_tokens=False))
                
                true_positives = len(true_tokens & pred_tokens)
                precision = true_positives / len(pred_tokens) if pred_tokens else 0
                recall = true_positives / len(true_tokens) if true_tokens else 0
                f1 = (2 * precision * recall) / (precision + recall) if (precision + recall) else 0
                
                precision_scores.append(precision)
                recall_scores.append(recall)
                f1_scores.append(f1)

    avg_wer = sum(wer_scores) / len(wer_scores) if wer_scores else 0.0
    avg_f1 = sum(f1_scores) / len(f1_scores) if f1_scores else 0.0
    avg_precision = sum(precision_scores) / len(precision_scores) if precision_scores else 0.0
    avg_recall = sum(recall_scores) / len(recall_scores) if recall_scores else 0.0

    log.info(f"Evaluation - WER: {avg_wer:.4f}, Precision: {avg_precision:.4f}, Recall: {avg_recall:.4f}, F1 Score: {avg_f1:.4f}")
    return avg_wer, avg_precision, avg_recall, avg_f1

NUM_EPOCHS = 3  # Reduced from 5 to 3
lowest_wer = float("inf")
early_stop_patience = 3
early_stop_counter = 0

# Main training and evaluation loop
for epoch in range(NUM_EPOCHS):
    avg_train_loss = train_epoch(qa_model, train_dl, optim)
    eval_wer, _, _, _ = evaluate(qa_model, valid_dl)  # Ignore F1, precision, recall in per-epoch eval

    log.info(f"Epoch {epoch+1}/{NUM_EPOCHS}, Training Loss: {avg_train_loss}, Evaluation - WER: {eval_wer:.4f}")

    if eval_wer < lowest_wer:
        lowest_wer = eval_wer
        early_stop_counter = 0
    else:
        early_stop_counter += 1
        if early_stop_counter >= early_stop_patience:
            log.info("Early stopping initiated.")
            break

# Final evaluation after training to get single F1 score
_, final_precision, final_recall, final_f1 = evaluate(qa_model, valid_dl)
log.info(f"Final Evaluation - Precision: {final_precision:.4f}, Recall: {final_recall:.4f}, F1 Score: {final_f1:.4f}")

Some weights of DistilBertForQuestionAnswering were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Training: 100%|██████████| 2320/2320 [09:02<00:00,  4.28it/s]
Evaluating: 100%|██████████| 15875/15875 [01:56<00:00, 135.79it/s]
INFO:__main__:Evaluation - WER: 11.1133, Precision: 0.2998, Recall: 0.7329, F1 Score: 0.3254
INFO:__main__:Epoch 1/3, Training Loss: 5.980311680456688, Evaluation - WER: 11.1133
Training: 100%|██████████| 2320/2320 [17:03<00:00,  2.27it/s]
Evaluating: 100%|██████████| 15875/15875 [06:07<00:00, 43.22it/s] 
INFO:__main__:Evaluation - WER: 6.6106, Precision: 0.4313, Recall: 0.6887, F1 Score: 0.4420
INFO:__main__:Epoch 2/3, Training Loss: 5.480455759270438, Evaluation - WER: 6.6106
Training: 100%|██████████| 2320/2320 [16:17<00:00,  2.37it/s]
Evaluating: 100%|█████████