In [3]:
import torch
import json
from transformers import DistilBertTokenizerFast, DistilBertForQuestionAnswering
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from torch.optim import AdamW
import jiwer
from sklearn.metrics import f1_score

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Data parsing function
def parse_data(file_path):
    with open(file_path, 'r') as file:
        data = json.load(file)
    
    contexts, questions, answers = [], [], []

    for item in data['data']:
        for paragraph in item['paragraphs']:
            context_text = paragraph['context']
            for qna in paragraph['qas']:
                question_text = qna['question']
                for answer in qna['answers']:
                    contexts.append(context_text.lower())
                    questions.append(question_text.lower())
                    answers.append(answer)
    
    return contexts, questions, answers

# Function to compute answer end positions
def compute_end_positions(answers, contexts):
    for answer, context in zip(answers, contexts):
        answer['text'] = answer['text'].lower()
        answer['end'] = answer['answer_start'] + len(answer['text'])

# Prepare inputs with tokenized start and end positions
def prepare_inputs(contexts, questions, answers, tokenizer, max_len):
    inputs = tokenizer(questions, contexts, max_length=max_len, padding="max_length", truncation=True)
    start_positions, end_positions = [], []

    for i, (answer, context) in enumerate(zip(answers, contexts)):
        answer_start = answer['answer_start']
        answer_end = answer['end']
        
        token_start = tokenizer.encode(context[:answer_start], add_special_tokens=False)
        token_end = tokenizer.encode(context[:answer_end], add_special_tokens=False)

        start_positions.append(len(token_start))
        end_positions.append(len(token_end) - 1)

    inputs.update({'start_positions': start_positions, 'end_positions': end_positions})
    return inputs

# Custom dataset class
class QuestionAnsweringDataset(Dataset):
    def __init__(self, encodings):
        self.encodings = encodings

    def __getitem__(self, index):
        return {key: torch.tensor(val[index]) for key, val in self.encodings.items()}

    def __len__(self):
        return len(self.encodings['input_ids'])

# Load data and compute end positions
train_contexts, train_questions, train_answers = parse_data('spoken_train-v1.1.json')
valid_contexts, valid_questions, valid_answers = parse_data('spoken_test-v1.1.json')

compute_end_positions(train_answers, train_contexts)
compute_end_positions(valid_answers, valid_contexts)

# Tokenize data
MAX_LEN = 512
MODEL_NAME = "distilbert-base-uncased"
tokenizer = DistilBertTokenizerFast.from_pretrained(MODEL_NAME)

train_data = prepare_inputs(train_contexts, train_questions, train_answers, tokenizer, MAX_LEN)
valid_data = prepare_inputs(valid_contexts, valid_questions, valid_answers, tokenizer, MAX_LEN)

train_dataset = QuestionAnsweringDataset(train_data)
valid_dataset = QuestionAnsweringDataset(valid_data)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=1)

# Model and optimizer
qa_model = DistilBertForQuestionAnswering.from_pretrained(MODEL_NAME).to(device)
optimizer = AdamW(qa_model.parameters(), lr=5e-5)

# Training function
def train_one_epoch(model, loader, optimizer):
    model.train()
    total_loss = 0

    for batch in tqdm(loader, desc="Training"):
        optimizer.zero_grad()
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        start_pos = batch["start_positions"].to(device)
        end_pos = batch["end_positions"].to(device)
        
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, start_positions=start_pos, end_positions=end_pos)
        loss = outputs.loss
        total_loss += loss.item()
        loss.backward()
        optimizer.step()

    return total_loss / len(loader)

# Evaluation function with F1 score calculation
def evaluate(model, loader):
    model.eval()
    word_error_rates = []
    all_true_spans = []
    all_pred_spans = []

    for batch in tqdm(loader, desc="Evaluating"):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        true_start = batch["start_positions"].to(device)
        true_end = batch["end_positions"].to(device)
        
        with torch.no_grad():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        
        start_preds = torch.argmax(outputs.start_logits, dim=1)
        end_preds = torch.argmax(outputs.end_logits, dim=1)

        for i in range(len(true_start)):
            pred_ans = tokenizer.decode(input_ids[i][start_preds[i]:end_preds[i]+1])
            actual_ans = tokenizer.decode(input_ids[i][true_start[i]:true_end[i]+1])

            if actual_ans.strip():
                wer = jiwer.wer(actual_ans, pred_ans)
                word_error_rates.append(wer)

                # Token-level F1 calculation
                true_tokens = set(tokenizer.encode(actual_ans, add_special_tokens=False))
                pred_tokens = set(tokenizer.encode(pred_ans, add_special_tokens=False))
                all_true_spans.append(true_tokens)
                all_pred_spans.append(pred_tokens)

    avg_wer = sum(word_error_rates) / len(word_error_rates) if word_error_rates else 0.0

    # Calculate token-level F1 for the entire program
    all_true_tokens = set().union(*all_true_spans)
    all_pred_tokens = set().union(*all_pred_spans)
    true_positives = len(all_true_tokens & all_pred_tokens)
    precision = true_positives / len(all_pred_tokens) if all_pred_tokens else 0
    recall = true_positives / len(all_true_tokens) if all_true_tokens else 0
    final_f1_score = (2 * precision * recall) / (precision + recall) if (precision + recall) else 0

    return avg_wer, final_f1_score

# Training loop with 3 epochs
EPOCHS = 3
for epoch in range(EPOCHS):
    avg_train_loss = train_one_epoch(qa_model, train_loader, optimizer)
    print(f"Epoch {epoch+1}/{EPOCHS} - Train Loss: {avg_train_loss:.4f}")

# Final evaluation with F1 score
wer, final_f1_score = evaluate(qa_model, valid_loader)
print(f"Final Evaluation - WER: {wer:.4f}, F1 Score: {final_f1_score:.4f}")

Token indices sequence length is longer than the specified maximum sequence length for this model (593 > 512). Running this sequence through the model will result in indexing errors
Some weights of DistilBertForQuestionAnswering were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Training: 100%|██████████| 2320/2320 [09:03<00:00,  4.27it/s]


Epoch 1/3 - Train Loss: 4.1328


Training: 100%|██████████| 2320/2320 [09:04<00:00,  4.26it/s]


Epoch 2/3 - Train Loss: 3.0676


Training:   1%|          | 18/2320 [00:04<09:01,  4.25it/s]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

Training: 100%|██████████| 2320/2320 [09:04<00:00,  4.26it/s]


Epoch 3/3 - Train Loss: 2.3063


Evaluating: 100%|██████████| 15875/15875 [01:48<00:00, 146.76it/s]

Final Evaluation - WER: 2.0521, F1 Score: 0.6840



