# In [None]:
import json
import torch
import torch.nn as nn
import os
from tqdm import tqdm
from transformers import BertModel, BertTokenizerFast, AdamW
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Input data files for the training and evaluation splits.
train_data_path = 'spoken_train-v1.1.json'
test_data_path = 'spoken_test-v1.1.json'
# Directory name for saved figures (not referenced by the code in this file section).
fig_save_path = 'figs'
MAX_LENGTH = 512  # default max_length for each tokenized (question, context) pair
MODEL_PATH = "bert-base-uncased"  # checkpoint shared by the tokenizer and the encoder

def load_data(path):
    """Read a SQuAD-style JSON file and flatten it into one row per answer.

    Walks ``data -> paragraphs -> qas -> answers``. Contexts, questions and
    answer texts are lower-cased. A question with several answers contributes
    one row per answer, each repeating its context and question.

    NOTE(review): ``str.lower`` can change the length of some non-ASCII strings.
    Any character offsets stored in the answers (e.g. ``answer_start``) may then
    no longer line up with the lower-cased context.

    Args:
        path: path of the JSON file to read.

    Returns:
        Three parallel lists ``(contexts, questions, answers)``. ``answers``
        holds the answer dicts from the file itself, with ``'text'`` lower-cased
        in place.
    """
    with open(path, 'rb') as handle:
        squad = json.load(handle)

    rows = []
    for article in squad['data']:
        for para in article['paragraphs']:
            ctx = para['context'].lower()
            for qa_item in para['qas']:
                q_text = qa_item['question'].lower()
                for ans in qa_item['answers']:
                    # Normalise the answer dict itself so downstream code sees the lower-cased text.
                    ans['text'] = ans['text'].lower()
                    rows.append((ctx, q_text, ans))

    contexts = [row[0] for row in rows]
    questions = [row[1] for row in rows]
    answers = [row[2] for row in rows]
    return contexts, questions, answers

# Parse both splits into parallel lists of contexts, questions and answer dicts.
train_contexts, train_questions, train_answers = load_data(path=train_data_path)
valid_contexts, valid_questions, valid_answers = load_data(path=test_data_path)

def add_answer_end(answers, contexts):
    """Annotate every answer dict in place with its character end offset.

    ``answer['answer_end']`` is set to ``answer['answer_start'] + len(answer['text'])``,
    i.e. the index one past the answer's last character.

    Args:
        answers: answer dicts with ``'answer_start'`` (int) and ``'text'`` (str) keys.
        contexts: accepted for backward compatibility. The end offset depends only
            on the answer itself, so it is not read.

    Returns:
        None. The dicts in ``answers`` are mutated.
    """
    # Iterate over the answers alone. Pairing them with ``contexts`` would silently
    # skip every answer beyond len(contexts), leaving it without an 'answer_end'.
    for answer in answers:
        answer['answer_end'] = answer['answer_start'] + len(answer['text'])

# Attach an 'answer_end' character offset to every answer of both splits.
add_answer_end(answers=train_answers, contexts=train_contexts)
add_answer_end(answers=valid_answers, contexts=valid_contexts)

# Fast BERT tokenizer loaded from the same checkpoint name as the encoder (MODEL_PATH).
tokenizerFast = BertTokenizerFast.from_pretrained(pretrained_model_name_or_path=MODEL_PATH)

def _find_answer_span(seq_ids, ans_ids, seq_type_ids=None):
    """Return the inclusive (start, end) token span of ``ans_ids`` inside ``seq_ids``.

    If ``seq_type_ids`` is given and marks some positions with 1 (the second,
    context, segment of a BERT-style pair encoding), only that segment is
    searched. Otherwise the whole sequence is searched. Returns (0, 0) when
    ``ans_ids`` is empty or is not found.
    """
    n = len(ans_ids)
    if n == 0:
        # An empty slice equals [] at every position, which would give the invalid span (0, -1).
        return 0, 0
    lo, hi = 0, len(seq_ids)
    if seq_type_ids is not None:
        context_positions = [pos for pos, t in enumerate(seq_type_ids) if t == 1]
        if context_positions:
            lo, hi = context_positions[0], context_positions[-1] + 1
    for i in range(lo, hi - n + 1):
        if seq_ids[i:i + n] == ans_ids:
            return i, i + n - 1
    return 0, 0


def encode_data(questions, contexts, answers, tokenizer, max_length=MAX_LENGTH):
    """Tokenize (question, context) pairs and label each answer span in token space.

    Each answer text is tokenized on its own, with its special tokens dropped.
    Its ids are then searched for inside the context segment of the matching pair
    encoding. The first match gives ``start_positions`` / ``end_positions`` as
    inclusive token indices. When the answer is empty or not found, both positions
    are 0. This happens, for example, when the answer was truncated away or
    tokenizes differently inside the context.

    Args:
        questions: question strings.
        contexts: context strings, aligned with ``questions``.
        answers: answer dicts with a ``'text'`` key, aligned with ``questions``.
        tokenizer: a Hugging Face tokenizer callable on text pairs.
        max_length: maximum length of each encoded pair (truncated and padded).

    Returns:
        The pair encoding, with ``'start_positions'`` and ``'end_positions'``
        (lists of ints, one per answer) added.
    """
    encodings = tokenizer(questions, contexts, max_length=max_length, truncation=True, padding=True)
    input_ids = encodings['input_ids']
    # For BERT-style tokenizers, token_type_ids are 1 on the context segment. They
    # keep a phrase that also occurs in the question from being labelled there.
    # Without them, the whole sequence is searched.
    type_ids = encodings.get('token_type_ids')

    # Tokenize all answers in one call. Without padding, [1:-1] drops exactly the
    # leading/trailing special tokens added around each single sequence.
    answer_texts = [answer['text'] for answer in answers]
    answer_ids = (
        [ids[1:-1] for ids in tokenizer(answer_texts, truncation=True)['input_ids']]
        if answer_texts else []
    )

    start_positions, end_positions = [], []
    for idx, ans_ids in enumerate(answer_ids):
        seq_type_ids = None if type_ids is None else type_ids[idx]
        start, end = _find_answer_span(input_ids[idx], ans_ids, seq_type_ids)
        start_positions.append(start)
        end_positions.append(end)

    encodings.update({'start_positions': start_positions, 'end_positions': end_positions})
    return encodings

# Tokenize both splits and attach token-level answer start/end labels.
train_encodings = encode_data(train_questions, train_contexts, train_answers, tokenizer=tokenizerFast)
valid_encodings = encode_data(valid_questions, valid_contexts, valid_answers, tokenizer=tokenizerFast)

class SQuAD_Dataset(Dataset):
    """Map-style dataset over a mapping of per-example feature lists.

    Item ``i`` is a dict holding ``torch.tensor(values[i])`` for every key of
    ``encodings``. The dataset length is the length of ``encodings['input_ids']``.
    """

    def __init__(self, encodings):
        # Mapping of feature name -> indexable sequence (e.g. a tokenizer's output).
        self.encodings = encodings

    def __getitem__(self, i):
        item = {}
        for key, values in self.encodings.items():
            item[key] = torch.tensor(values[i])
        return item

    def __len__(self):
        return len(self.encodings['input_ids'])

# Wrap the encodings as datasets. Training batches are shuffled;
# validation is iterated one example at a time, in order.
train_dataset = SQuAD_Dataset(train_encodings)
valid_dataset = SQuAD_Dataset(valid_encodings)

train_data_loader = DataLoader(train_dataset, shuffle=True, batch_size=16)
valid_data_loader = DataLoader(valid_dataset, batch_size=1)

class QAModel(nn.Module):
    """Span-extraction QA model: a pretrained BERT encoder plus a linear start/end head.

    The head reads the concatenation of the encoder's last and third-to-last
    hidden layers (after dropout) and emits two scores per token.
    """

    def __init__(self):
        super().__init__()
        self.bert = BertModel.from_pretrained(MODEL_PATH)
        self.drop_out = nn.Dropout(0.1)
        # Two hidden layers are concatenated, hence 2 * hidden_size input features.
        # The width is read from the checkpoint's config instead of a hard-coded 768,
        # so checkpoints with another hidden size (e.g. bert-large, 1024) also work.
        self.linear = nn.Linear(self.bert.config.hidden_size * 2, 2)

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Score every token as a potential answer start and answer end.

        Returns:
            (start_logits, end_logits): the two output channels of the linear head,
            with the trailing size-1 dimension squeezed away. For the usual
            (batch, seq_len) inputs these are (batch, seq_len).
        """
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            output_hidden_states=True,
        )
        # hidden_states[-1] is the final encoder layer; [-3] is two layers below it.
        hidden_states = torch.cat((outputs.hidden_states[-1], outputs.hidden_states[-3]), dim=-1)
        logits = self.linear(self.drop_out(hidden_states))
        start_logits, end_logits = logits.split(1, dim=-1)
        return start_logits.squeeze(-1), end_logits.squeeze(-1)

model = QAModel()
model.to(device)  # nn.Module.to moves the parameters in place (and returns the same module)

def focal_loss(start_logits, end_logits, start_positions, end_positions, gamma=1):
    """Focal loss averaged over the answer-start and answer-end heads.

    For each head, FL = -(1 - p_t) ** gamma * log(p_t), where p_t is the softmax
    probability (over dim 1) of the target index. It is averaged over the batch,
    as in NLLLoss's default "mean" reduction. ``gamma=0`` reduces to cross-entropy.

    Args:
        start_logits, end_logits: 2-D (batch, num_positions) scores.
        start_positions, end_positions: (batch,) integer target indices.
        gamma: focusing exponent applied to (1 - p).

    Returns:
        Scalar tensor: the mean of the start-head and end-head losses.
    """
    def _head_loss(logits, targets):
        # log_softmax stays finite when a probability underflows. torch.log(softmax(x))
        # would produce -inf there and turn the loss/gradients into inf/NaN.
        log_probs = torch.log_softmax(logits, dim=1)
        weight = (1 - log_probs.exp()).pow(gamma)
        return nn.functional.nll_loss(weight * log_probs, targets)

    return (_head_loss(start_logits, start_positions) + _head_loss(end_logits, end_positions)) / 2

# AdamW over all model parameters. torch.optim.AdamW replaces the deprecated
# transformers.AdamW. eps=1e-6 keeps the epsilon that the transformers implementation
# used by default (torch's own default is 1e-8).
optim = torch.optim.AdamW(model.parameters(), lr=2e-5, weight_decay=2e-2, eps=1e-6)

def train_model(model, dataloader, optimizer=None):
    """Run one training epoch and return (mean loss, mean span accuracy).

    Args:
        model: module called as ``model(input_ids, attention_mask, token_type_ids)``
            that returns ``(start_logits, end_logits)``.
        dataloader: yields dict batches with 'input_ids', 'attention_mask',
            'token_type_ids', 'start_positions' and 'end_positions' tensors.
        optimizer: optimizer to step. Defaults to the module-level ``optim``, so
            existing two-argument calls behave as before.

    Returns:
        ``(loss, acc)``, both averaged over batches. ``loss`` is the focal loss.
        ``acc`` is the mean of the start and end exact-position accuracies.

    Raises:
        ZeroDivisionError: if ``dataloader`` yields no batches (its length is 0).
    """
    if optimizer is None:
        optimizer = optim
    model.train()
    total_loss, total_acc = 0.0, 0.0
    for batch in tqdm(dataloader, desc='Training'):
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        token_type_ids = batch['token_type_ids'].to(device)
        start_positions = batch['start_positions'].to(device)
        end_positions = batch['end_positions'].to(device)

        start_logits, end_logits = model(input_ids, attention_mask, token_type_ids)
        loss = focal_loss(start_logits, end_logits, start_positions, end_positions)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        # Fraction of the batch whose argmax position equals the label, averaged over both heads.
        start_acc = (start_logits.argmax(dim=1) == start_positions).float().mean().item()
        end_acc = (end_logits.argmax(dim=1) == end_positions).float().mean().item()
        total_acc += (start_acc + end_acc) / 2
    return total_loss / len(dataloader), total_acc / len(dataloader)

def evaluate_model(model, dataloader):
    """Predict answer spans on ``dataloader`` and score them against the labels with WER.

    The predicted start and end are the independent argmaxes of the start and end
    logits. Predicted and labelled spans are both decoded with ``tokenizerFast``.
    An empty decode (e.g. when end < start) is replaced by the placeholder ``"$"``.

    Returns:
        The score returned by the ``evaluate`` library's "wer" metric.
    """
    model.eval()
    predicted_texts = []
    reference_texts = []

    def _span_text(ids, first, last):
        # Decode the inclusive token span [first, last]; substitute "$" for an empty string.
        return tokenizerFast.decode(ids[first:last + 1]) or "$"

    with torch.no_grad():
        for batch in tqdm(dataloader, desc='Evaluating'):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            token_type_ids = batch['token_type_ids'].to(device)
            gold_starts = batch['start_positions'].to(device)
            gold_ends = batch['end_positions'].to(device)

            start_logits, end_logits = model(input_ids, attention_mask, token_type_ids)
            pred_starts = start_logits.argmax(dim=1)
            pred_ends = end_logits.argmax(dim=1)

            for row, p_start, p_end, g_start, g_end in zip(
                input_ids, pred_starts, pred_ends, gold_starts, gold_ends
            ):
                predicted_texts.append(_span_text(row, p_start, p_end))
                reference_texts.append(_span_text(row, g_start, g_end))

    from evaluate import load
    wer = load("wer")
    return wer.compute(predictions=predicted_texts, references=reference_texts)

# Main training loop: each epoch runs one pass over the training data, then a WER evaluation.
EPOCHS = 6
for epoch in range(EPOCHS):
    train_loss, train_acc = train_model(model, dataloader=train_data_loader)
    print("Epoch {}: Train Loss = {:.4f}, Train Accuracy = {:.4f}".format(epoch + 1, train_loss, train_acc))
    wer_score = evaluate_model(model, dataloader=valid_data_loader)
    print("Epoch {}: WER = {:.4f}".format(epoch + 1, wer_score))
