# In [None]:
import json
import torch
import torch.nn as nn
from tqdm import tqdm
from transformers import BertModel, BertTokenizerFast, AdamW
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ExponentialLR
from evaluate import load

# Run on the GPU when PyTorch can see one, otherwise on the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Input corpora (JSON files consumed by load_data)
train_data_path = 'spoken_train-v1.1.json'
test_data_path = 'spoken_test-v1.1.json'

# Pretrained checkpoint shared by the tokenizer and the encoder
MODEL_PATH = "bert-base-uncased"

# MAX_LENGTH caps both the character window in truncate_contexts and the
# token length in encode_data; doc_stride is forwarded to the tokenizer
MAX_LENGTH = 512
doc_stride = 128

# Load and preprocess data
def load_data(path):
    """Read a SQuAD-style JSON file and flatten it into three parallel lists.

    The file is expected to have the layout
    ``data -> paragraphs -> (context, qas -> (question, answers))``.
    One entry is produced per answer, so a question with several answers
    contributes several (context, question, answer) triples.

    All text is lower-cased.  Each answer dict (taken from the parsed JSON)
    is updated in place: ``text`` is lower-cased and ``answer_end`` is set to
    ``answer_start + len(text)``.

    Args:
        path: Path of the JSON file to read.

    Returns:
        A ``(contexts, questions, answers)`` tuple of equally long lists;
        ``answers`` holds the per-answer dicts.
    """
    # Decode explicitly as UTF-8 instead of relying on the platform default,
    # which breaks on non-ASCII text under non-UTF-8 locales.
    with open(path, 'r', encoding='utf-8') as f:
        raw_data = json.load(f)
    contexts, questions, answers = [], [], []

    for group in raw_data['data']:
        for paragraph in group['paragraphs']:
            context = paragraph['context'].lower()
            for qa in paragraph['qas']:
                question = qa['question'].lower()
                for answer in qa['answers']:
                    answer['text'] = answer['text'].lower()
                    # Exclusive character offset of the end of the answer
                    answer['answer_end'] = answer['answer_start'] + len(answer['text'])
                    contexts.append(context)
                    questions.append(question)
                    answers.append(answer)
    return contexts, questions, answers

# Flatten both corpora; the test file is used as the validation split
(train_contexts, train_questions, train_answers) = load_data(path=train_data_path)
(valid_contexts, valid_questions, valid_answers) = load_data(path=test_data_path)

# Truncate contexts if they exceed maximum length
def truncate_contexts(contexts, answers, max_length=MAX_LENGTH):
    """Cut over-long contexts down to a character window around the answer.

    A context longer than ``max_length`` characters is replaced by a
    ``max_length``-character slice centred (as far as the context borders
    allow) on the midpoint of its answer.  The matching answer dict is updated
    in place so that its offsets refer to the slice: ``answer_start`` is
    shifted, and ``answer_end`` is rewritten as
    ``answer_start + len(text)`` so the two offsets stay consistent.
    Contexts that already fit are returned unchanged and their answers are
    not touched.

    Args:
        contexts: List of context strings.
        answers: List of answer dicts, parallel to ``contexts``; each needs
            ``'answer_start'`` and ``'text'``.
        max_length: Maximum context length, in characters.

    Returns:
        A new list with one (possibly truncated) context per input context.
    """
    truncated_contexts = []
    for i, context in enumerate(contexts):
        if len(context) <= max_length:
            truncated_contexts.append(context)
            continue

        answer = answers[i]
        answer_len = len(answer['text'])
        answer_start = answer['answer_start']
        mid = (answer_start + answer_start + answer_len) // 2
        # Centre the window on the answer, then clamp it inside the context
        para_start = max(0, min(mid - max_length // 2, len(context) - max_length))
        truncated_contexts.append(context[para_start: para_start + max_length])

        # Re-base both offsets; previously only answer_start was shifted,
        # leaving a stale answer_end behind.
        answer['answer_start'] = answer_start - para_start
        answer['answer_end'] = answer['answer_start'] + answer_len
    return truncated_contexts

# Only the training contexts are windowed around their answers
train_contexts = truncate_contexts(contexts=train_contexts, answers=train_answers)

# Fast tokenizer for the pretrained checkpoint
tokenizer = BertTokenizerFast.from_pretrained(MODEL_PATH)

def encode_data(questions, contexts, answers, tokenizer, max_length=MAX_LENGTH, stride=doc_stride):
    """Tokenize question/context pairs and attach answer token positions.

    Each pair is encoded as one sequence, truncated to ``max_length`` tokens
    and padded to a common length.  The answer span is located by tokenizing
    the answer text on its own and searching for that token sequence inside
    the context segment of the encoded pair.  When the answer cannot be found
    (for example because truncation removed it) or tokenizes to nothing, both
    positions are set to 0.

    Args:
        questions: List of question strings.
        contexts: List of context strings, parallel to ``questions``.
        answers: List of answer dicts with a ``'text'`` entry, parallel to
            ``questions``.
        tokenizer: Callable tokenizer returning ``'input_ids'`` and
            ``'token_type_ids'`` as Python lists.
        max_length: Maximum encoded length in tokens.
        stride: Forwarded to the tokenizer unchanged.
            NOTE(review): no overflowing-token windows are requested here, so
            this probably has no effect - confirm against the tokenizer docs.

    Returns:
        The tokenizer output, extended with ``'start_positions'`` and
        ``'end_positions'`` (inclusive token indices, one per example).
    """
    encodings = tokenizer(questions, contexts, max_length=max_length, truncation=True, stride=stride, padding=True)
    start_positions, end_positions = [], []

    for idx, answer in enumerate(answers):
        answer_ids = tokenizer(answer['text'], add_special_tokens=False)['input_ids']
        input_ids = encodings['input_ids'][idx]
        span_len = len(answer_ids)

        # Default label when no span can be located
        start = end = 0

        # An empty answer would "match" at index 0 and give end == -1, which
        # is not a valid class index for the loss; leave it at the default.
        if span_len:
            # For BERT-style pair encodings token type 1 marks the second
            # (context) segment.  Starting the search there prevents labelling
            # an occurrence of the answer text inside the question.
            token_types = encodings['token_type_ids'][idx]
            search_from = token_types.index(1) if 1 in token_types else 0

            for i in range(search_from, len(input_ids) - span_len + 1):
                if input_ids[i:i + span_len] == answer_ids:
                    start, end = i, i + span_len - 1
                    break

        start_positions.append(start)
        end_positions.append(end)

    encodings.update({'start_positions': start_positions, 'end_positions': end_positions})
    return encodings

# Encode both splits with the shared tokenizer and default length settings
train_encodings = encode_data(
    train_questions,
    train_contexts,
    train_answers,
    tokenizer,
)
valid_encodings = encode_data(
    valid_questions,
    valid_contexts,
    valid_answers,
    tokenizer,
)

# Dataset class
class SQuADDataset(Dataset):
    """Map-style dataset over a mapping of feature name -> per-example values.

    Indexing returns one example as a dict of tensors, with one entry per
    feature in ``encodings``.  The dataset length is the number of
    ``'input_ids'`` rows.
    """

    def __init__(self, encodings):
        # Mapping such as the one returned by encode_data; stored as-is
        self.encodings = encodings

    def __len__(self):
        return len(self.encodings['input_ids'])

    def __getitem__(self, i):
        sample = {}
        for name, column in self.encodings.items():
            sample[name] = torch.tensor(column[i])
        return sample

# Wrap the encoded splits as datasets
train_dataset, valid_dataset = SQuADDataset(train_encodings), SQuADDataset(valid_encodings)

# Shuffled mini-batches for training; one example at a time for validation
train_loader = DataLoader(dataset=train_dataset, batch_size=16, shuffle=True)
valid_loader = DataLoader(dataset=valid_dataset, batch_size=1)

# Model definition
class QAModel(nn.Module):
    """BERT encoder with a linear head that scores each token as span start/end.

    The head reads the concatenation of two encoder hidden-state layers
    (the last one and the third from last) and outputs two values per token.
    """

    def __init__(self):
        super().__init__()
        self.bert = BertModel.from_pretrained(MODEL_PATH)
        self.drop_out = nn.Dropout(0.1)
        # forward() concatenates two hidden-state layers, so the head input is
        # twice the encoder width.  Reading the width from the config keeps the
        # head correct if MODEL_PATH points at a checkpoint that is not 768 wide.
        hidden_size = self.bert.config.hidden_size
        self.fc = nn.Linear(hidden_size * 2, 2)

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Return ``(start_logits, end_logits)`` for a batch of encoded pairs.

        Both tensors keep a trailing dimension of size 1 (the result of
        splitting the two-channel head output along its last axis); callers
        that need per-token scores should squeeze it.
        """
        outputs = self.bert(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, output_hidden_states=True)
        # Concatenate the final layer with the third-from-last along the feature axis
        hidden_states = torch.cat((outputs.hidden_states[-1], outputs.hidden_states[-3]), dim=-1)
        logits = self.fc(self.drop_out(hidden_states))
        return logits.split(1, dim=-1)

# Build the QA network, then move its parameters to the selected device
model = QAModel()
model = model.to(device)

# Focal loss function
def focal_loss(start_logits, end_logits, start_positions, end_positions, gamma=1):
    """Focal loss for span prediction, averaged over the start and end heads.

    For each head the per-class log-probabilities are scaled by
    ``(1 - p) ** gamma`` before the negative log-likelihood of the target
    index is taken (mean over the batch), so confidently correct examples
    contribute less.

    Args:
        start_logits: Start scores; softmax is taken over dimension 1.
        end_logits: End scores, same layout as ``start_logits``.
        start_positions: Target start indices, one per example.
        end_positions: Target end indices, one per example.
        gamma: Focusing exponent applied to ``1 - p``.

    Returns:
        Scalar tensor: the mean of the start and end focal losses.
    """
    def _focal_nll(logits, targets):
        # log_softmax stays finite where log(softmax(x)) underflows to -inf,
        # which previously produced inf/NaN losses for confident mistakes.
        log_probs = nn.functional.log_softmax(logits, dim=1)
        modulator = (1 - log_probs.exp()).pow(gamma)
        return nn.functional.nll_loss(modulator * log_probs, targets)

    fl_start = _focal_nll(start_logits, start_positions)
    fl_end = _focal_nll(end_logits, end_positions)

    return (fl_start + fl_end) / 2

# Training and evaluation functions
# Use PyTorch's AdamW: the implementation exported by transformers is
# deprecated upstream in favour of torch.optim.AdamW.  eps=1e-6 keeps the
# epsilon that the transformers implementation used by default.
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, weight_decay=2e-2, eps=1e-6)
# Multiply the learning rate by 0.9 on every scheduler.step() call
scheduler = ExponentialLR(optimizer, gamma=0.9)

def train(model, dataloader, epoch):
    """Run one training epoch and report the mean loss and accuracy.

    Uses the module-level ``optimizer``, ``scheduler`` and ``device``.  The
    scheduler is stepped once, after the last batch.

    Args:
        model: Network returning ``(start_logits, end_logits)``; a trailing
            singleton dimension on the logits is squeezed away.
        dataloader: Yields dict batches with ``input_ids``,
            ``attention_mask``, ``token_type_ids``, ``start_positions`` and
            ``end_positions``.
        epoch: Zero-based epoch index, used only for the progress-bar label.

    Returns:
        ``(mean_loss, mean_accuracy)`` averaged over batches, where accuracy is
        the mean of the start-position and end-position exact-match rates.
    """
    model.train()
    total_loss, total_acc = 0, 0

    for batch in tqdm(dataloader, desc=f"Epoch {epoch + 1} Training"):
        optimizer.zero_grad()

        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        token_type_ids = batch['token_type_ids'].to(device)
        start_positions = batch['start_positions'].to(device)
        end_positions = batch['end_positions'].to(device)

        start_logits, end_logits = model(input_ids, attention_mask, token_type_ids)
        # Drop the trailing singleton so both the loss and the accuracy see
        # one score per token.
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)
        loss = focal_loss(start_logits, end_logits, start_positions, end_positions)

        loss.backward()
        optimizer.step()

        # Argmax must run on the squeezed logits: on the unsqueezed ones the
        # predictions have an extra axis and the comparison with the targets
        # broadcasts into a batch-by-batch matrix, which gave a meaningless
        # accuracy.
        start_pred = torch.argmax(start_logits, dim=1)
        end_pred = torch.argmax(end_logits, dim=1)
        start_acc = (start_pred == start_positions).float().mean().item()
        end_acc = (end_pred == end_positions).float().mean().item()

        total_loss += loss.item()
        total_acc += (start_acc + end_acc) / 2

    scheduler.step()
    return total_loss / len(dataloader), total_acc / len(dataloader)

def evaluate(model, dataloader):
    """Compute the word error rate of predicted answer spans against the labels.

    For every example the predicted span (argmax start to argmax end,
    inclusive) and the labelled span are decoded back to text with the
    module-level ``tokenizer``.  Empty strings are replaced by ``"$"`` before
    being passed to the metric.

    Args:
        model: Network returning ``(start_logits, end_logits)``; a trailing
            singleton dimension on the logits is squeezed away.
        dataloader: Yields dict batches with ``input_ids``,
            ``attention_mask``, ``token_type_ids``, ``start_positions`` and
            ``end_positions``.

    Returns:
        The value returned by the ``wer`` metric over all examples.
    """
    model.eval()
    predictions, references = [], []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            token_type_ids = batch['token_type_ids'].to(device)

            start_logits, end_logits = model(input_ids, attention_mask, token_type_ids)
            # Squeeze before argmax so there is exactly one index per example,
            # and convert to plain ints instead of slicing with one-element
            # tensors.
            start_pred = torch.argmax(start_logits.squeeze(-1), dim=1).tolist()
            end_pred = torch.argmax(end_logits.squeeze(-1), dim=1).tolist()
            start_true = batch['start_positions'].tolist()
            end_true = batch['end_positions'].tolist()

            for i, (p_start, p_end) in enumerate(zip(start_pred, end_pred)):
                # A predicted end before the start gives an empty slice, which
                # decodes to "" and is mapped to the "$" placeholder below.
                pred_answer = tokenizer.decode(input_ids[i][p_start:p_end + 1])
                true_answer = tokenizer.decode(input_ids[i][start_true[i]:end_true[i] + 1])
                predictions.append(pred_answer if pred_answer else "$")
                references.append(true_answer if true_answer else "$")

    wer_metric = load("wer")
    wer_score = wer_metric.compute(predictions=predictions, references=references)
    return wer_score

# Main training loop
# Number of passes over the training set
EPOCHS = 6
# One WER value per completed epoch
wer_list = []

for epoch in range(EPOCHS):
    # Optimise for one epoch and report the averaged training statistics
    train_loss, train_acc = train(model, train_loader, epoch)
    print(f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.4f}")

    # Score the current weights on the validation loader
    wer_score = evaluate(model, valid_loader)
    wer_list += [wer_score]
    print(f"WER: {wer_score:.4f}")

print("WER List:", wer_list)
