In [None]:
import torch
import json
from transformers import DistilBertTokenizerFast, DistilBertForQuestionAnswering, AdamW, get_linear_schedule_with_warmup
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import jiwer
from sklearn.metrics import f1_score

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load and process data
def read_data(filepath):
    """Load a SQuAD-style JSON file and flatten it into parallel lists.

    The file must contain ``data -> paragraphs -> (context, qas)`` where every
    entry of ``qas`` has a ``question`` and a list of ``answers``.

    Args:
        filepath: Path of the JSON file to read.

    Returns:
        A ``(contexts, questions, answers)`` tuple of equally long lists with
        one entry per answer. Contexts and questions are lower-cased; each
        answer is the dict object found in the file (not a copy).
    """
    # JSON is UTF-8 by specification; do not rely on the platform default encoding.
    with open(filepath, 'r', encoding='utf-8') as f:
        data = json.load(f)

    contexts, questions, answers = [], [], []
    for group in data['data']:
        for paragraph in group['paragraphs']:
            context = paragraph['context'].lower()
            for qa in paragraph['qas']:
                question = qa['question'].lower()
                # A question with several answers yields one example per answer.
                for answer in qa['answers']:
                    contexts.append(context)
                    questions.append(question)
                    answers.append(answer)

    return contexts, questions, answers

def set_answer_boundaries(answers, contexts):
    """Lower-case each answer and record its character span in the context.

    Every answer dict is modified in place: ``text`` is lower-cased,
    ``answer_start`` is re-aligned when the annotated offset is one or two
    characters too large, and ``end_position`` (exclusive character offset)
    is added.

    Args:
        answers: List of dicts with ``text`` and ``answer_start`` keys.
        contexts: List of context strings, parallel to ``answers``.

    Returns:
        None. The answer dicts are updated in place.
    """
    for answer, context in zip(answers, contexts):
        text = answer['text'].lower()
        start = answer['answer_start']

        # Annotated offsets can be off by a character or two; use the context
        # to find the shift at which the answer text really occurs. When no
        # shift matches, the annotated offset is kept unchanged.
        for shift in (0, 1, 2):
            candidate = start - shift
            if candidate < 0:
                break
            if context[candidate:candidate + len(text)].lower() == text:
                start = candidate
                break

        answer['text'] = text
        answer['answer_start'] = start
        answer['end_position'] = start + len(text)

# Dataset and tokenization
class QuestionAnswerDataset(Dataset):
    """Pairs tokenized question/context encodings with answer token spans.

    ``encodings`` must expose ``items()``, ``input_ids`` and
    ``char_to_token`` (as a fast-tokenizer ``BatchEncoding`` does).
    ``answers`` is a parallel list of dicts holding the character offsets
    ``answer_start`` (inclusive) and ``end_position`` (exclusive) of the
    answer inside the context.
    """

    def __init__(self, encodings, answers):
        self.encodings = encodings
        self.answers = answers

    def _token_position(self, idx, char_index):
        """Map a context character offset of example ``idx`` to a token index.

        The context is the second sequence of each encoded pair, hence
        ``sequence_index=1``. When the character has no token (for instance
        because the context was truncated), the sequence length is returned;
        the question-answering head treats that index as "ignore".
        """
        token_index = self.encodings.char_to_token(idx, char_index, sequence_index=1)
        if token_index is None:
            return len(self.encodings.input_ids[idx])
        return token_index

    def __getitem__(self, idx):
        """Return the model inputs plus start/end token positions for ``idx``."""
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        answer = self.answers[idx]
        start_char = answer['answer_start']
        # end_position is exclusive, so the last answer character is one before
        # it; never step in front of the start for an empty answer text.
        last_char = max(answer['end_position'] - 1, start_char)
        # The model expects token indices, not character offsets.
        item['start_positions'] = torch.tensor(self._token_position(idx, start_char))
        item['end_positions'] = torch.tensor(self._token_position(idx, last_char))
        return item

    def __len__(self):
        return len(self.encodings.input_ids)

# Load the training and validation data
train_contexts, train_questions, train_answers = read_data('spoken_train-v1.1.json')
valid_contexts, valid_questions, valid_answers = read_data('spoken_test-v1.1.json')

# Adds the character end offset to every answer dict (in place)
set_answer_boundaries(train_answers, train_contexts)
set_answer_boundaries(valid_answers, valid_contexts)

# Model and tokenizer setup
MODEL_NAME = "distilbert-base-uncased"
tokenizer = DistilBertTokenizerFast.from_pretrained(MODEL_NAME)
MAX_SEQ_LEN = 512
# NOTE(review): DOC_STRIDE is not passed to the tokenizer calls below, so
# over-long contexts are truncated rather than split into overlapping windows.
DOC_STRIDE = 128

# Questions are the first sequence and contexts the second; every example is
# padded or truncated to MAX_SEQ_LEN tokens.
train_encodings = tokenizer(train_questions, train_contexts, max_length=MAX_SEQ_LEN, padding='max_length', truncation=True)
valid_encodings = tokenizer(valid_questions, valid_contexts, max_length=MAX_SEQ_LEN, padding='max_length', truncation=True)

train_data = QuestionAnswerDataset(train_encodings, train_answers)
valid_data = QuestionAnswerDataset(valid_encodings, valid_answers)

train_dataloader = DataLoader(train_data, batch_size=16, shuffle=True)
valid_dataloader = DataLoader(valid_data, batch_size=1)

# Model initialization
model = DistilBertForQuestionAnswering.from_pretrained(MODEL_NAME).to(device)
# Use the PyTorch optimizer instead of the deprecated transformers.AdamW.
# eps and weight_decay are set explicitly to the defaults of the transformers
# implementation so the optimisation hyperparameters stay the same.
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, eps=1e-6, weight_decay=0.0)
EPOCHS = 3
# One scheduler step is taken per batch, so the schedule spans every batch of every epoch
total_steps = len(train_dataloader) * EPOCHS
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)

# Training and evaluation functions
def train_step(model, dataloader, optimizer, scheduler):
    """Run one pass over ``dataloader`` and return the mean batch loss.

    For every batch the gradients are reset, the model is called with the
    input ids, attention mask and start/end positions, and the optimizer and
    the scheduler are each stepped once.
    """
    model.train()
    field_names = ('input_ids', 'attention_mask', 'start_positions', 'end_positions')
    running_loss = 0

    for batch in tqdm(dataloader, desc='Training'):
        optimizer.zero_grad()

        # Move exactly the tensors the model consumes onto the target device
        model_inputs = {name: batch[name].to(device) for name in field_names}
        batch_loss = model(**model_inputs).loss

        running_loss += batch_loss.item()
        batch_loss.backward()
        optimizer.step()
        scheduler.step()

    return running_loss / len(dataloader)

def compute_wer(predicted_text, true_text):
    """Return the word error rate of ``predicted_text`` against ``true_text``.

    Args:
        predicted_text: Hypothesis string.
        true_text: Reference string.

    Returns:
        ``jiwer.wer(true_text, predicted_text)`` when the reference contains
        any non-whitespace text. For a blank reference the rate is undefined
        (there are no reference words to divide by), so the result is 0.0
        when the prediction is blank as well and 1.0 otherwise.
    """
    if true_text.strip():
        return jiwer.wer(true_text, predicted_text)
    # A non-blank prediction for a blank reference is an error, not a perfect score.
    return 1.0 if predicted_text.strip() else 0.0

def _span_f1(pred_tokens, true_tokens):
    """Return the F1 overlap between two token lists, counting repeated tokens."""
    from collections import Counter  # local import: only this helper needs it

    if not pred_tokens or not true_tokens:
        # Two empty spans agree completely; a single empty span is a total miss.
        return float(pred_tokens == true_tokens)

    overlap = sum((Counter(pred_tokens) & Counter(true_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(true_tokens)
    return 2 * precision * recall / (precision + recall)


def evaluate(model, dataloader):
    """Score the model's predicted answer spans on ``dataloader``.

    For every example the predicted span is taken from the arg-max of the
    start and end logits, and the reference span from the batch's
    ``start_positions`` / ``end_positions``. Both pairs are used as inclusive
    indices into ``input_ids``.

    Args:
        model: Model called with ``input_ids`` and ``attention_mask`` whose
            output exposes ``start_logits`` and ``end_logits``.
        dataloader: Iterable of batches holding ``input_ids``,
            ``attention_mask``, ``start_positions`` and ``end_positions``.

    Returns:
        ``(avg_wer, avg_f1)``: the mean word error rate of the decoded spans
        and the mean per-example token F1. Both are 0.0 when no example was
        evaluated.
    """
    model.eval()
    wer_scores = []
    f1_scores = []

    for batch in tqdm(dataloader, desc='Evaluating'):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)

        with torch.no_grad():
            output = model(input_ids=input_ids, attention_mask=attention_mask)

        pred_starts = torch.argmax(output.start_logits, dim=1).tolist()
        pred_ends = torch.argmax(output.end_logits, dim=1).tolist()
        true_starts = batch['start_positions'].tolist()
        true_ends = batch['end_positions'].tolist()

        rows = zip(input_ids.tolist(), pred_starts, pred_ends, true_starts, true_ends)
        for token_ids, pred_start, pred_end, true_start, true_end in rows:
            # A predicted end in front of the start yields an empty span.
            pred_span = token_ids[pred_start:pred_end + 1]
            true_span = token_ids[true_start:true_end + 1]

            predicted_text = tokenizer.decode(pred_span)
            true_text = tokenizer.decode(true_span)
            wer_scores.append(compute_wer(predicted_text, true_text))

            # F1 is computed per example and averaged afterwards; pooling the
            # tokens of all examples would only measure vocabulary overlap.
            f1_scores.append(_span_f1(pred_span, true_span))

    avg_wer = sum(wer_scores) / len(wer_scores) if wer_scores else 0.0
    f1_score_model = sum(f1_scores) / len(f1_scores) if f1_scores else 0.0

    return avg_wer, f1_score_model

# Fine-tune for the configured number of epochs, reporting the mean batch loss of each pass
for epoch in range(EPOCHS):
    avg_train_loss = train_step(
        model=model,
        dataloader=train_dataloader,
        optimizer=optimizer,
        scheduler=scheduler,
    )
    print(f"Epoch {epoch+1}/{EPOCHS} - Loss: {avg_train_loss}")

# Score the fine-tuned model once on the validation split
final_metrics = evaluate(model=model, dataloader=valid_dataloader)
wer, final_f1_score = final_metrics
print(f"Final Evaluation - WER: {wer}, F1 Score: {final_f1_score}")

  torch.utils._pytree._register_pytree_node(
Some weights of DistilBertForQuestionAnswering were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Training: 100%|██████████| 2320/2320 [09:02<00:00,  4.28it/s]


Epoch 1/3 - Loss: 5.883984201118864


Training: 100%|██████████| 2320/2320 [09:03<00:00,  4.27it/s]


Epoch 2/3 - Loss: 5.356978672126244


Training: 100%|██████████| 2320/2320 [09:03<00:00,  4.27it/s]


Epoch 3/3 - Loss: 4.911500815687509


Evaluating:  67%|██████▋   | 10585/15875 [01:14<00:37, 139.76it/s]