In [1]:
import json
import torch
import torch.nn as nn
from tqdm import tqdm
from transformers import AlbertModel, AlbertTokenizerFast
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import f1_score
from jiwer import wer
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def get_data(path):
    """Load a SQuAD-style JSON file and flatten it to one row per answer.

    Args:
        path: Path of a JSON file laid out as
            ``data -> paragraphs -> qas -> answers``.

    Returns:
        A tuple ``(num_q, contexts, questions, answers)``. ``num_q`` is the
        number of questions visited. The three lists are aligned and hold one
        entry per answer: the lower-cased context, the lower-cased question,
        and the answer object exactly as it appears in the file.
    """
    with open(path, 'rb') as handle:
        payload = json.load(handle)

    contexts, questions, answers = [], [], []
    num_q = 0

    for article in payload['data']:
        for para in article['paragraphs']:
            passage = para['context']
            for qa in para['qas']:
                prompt = qa['question']
                num_q += 1
                replies = qa['answers']
                # One (context, question) copy per answer keeps the lists aligned.
                contexts.extend(passage.lower() for _ in replies)
                questions.extend(prompt.lower() for _ in replies)
                answers.extend(replies)
    return num_q, contexts, questions, answers

# Flatten the Spoken-SQuAD training file into aligned context/question/answer lists.
num_q, train_contexts, train_questions, train_answers = get_data('Spoken-SQuAD-master/spoken_train-v1.1.json')
# The "valid" split used throughout this script is read from the test file.
num_q_valid, valid_contexts, valid_questions, valid_answers = get_data('Spoken-SQuAD-master/spoken_test-v1.1.json')

def add_answer_end(answers):
    """Lower-case each answer's text and record its end offset, in place.

    For every dict in ``answers`` the ``'text'`` value is replaced by its
    lower-cased form, and ``'answer_end'`` is set to ``'answer_start'`` plus
    the length of that lower-cased text. Nothing is returned.
    """
    for entry in answers:
        lowered = entry['text'].lower()
        entry['text'] = lowered
        entry['answer_end'] = entry['answer_start'] + len(lowered)

# Lower-case the answer texts and attach 'answer_end' offsets (mutates the dicts in place).
add_answer_end(train_answers)
add_answer_end(valid_answers)
# Maximum number of tokens per encoded (question, context) pair.
MAX_LENGTH = 512
# Checkpoint name passed to the tokenizer's from_pretrained below.
MODEL_PATH = "albert-base-v2"
# NOTE(review): `stride` appears to have an effect only together with
# return_overflowing_tokens=True, which is not passed here — confirm whether
# sliding windows over long contexts were intended.
doc_stride = 128

tokenizerFast = AlbertTokenizerFast.from_pretrained(MODEL_PATH)
# Encode each (question, context) pair; one encoding per answer row, truncated to MAX_LENGTH and padded.
train_encodings_fast = tokenizerFast(train_questions, train_contexts, max_length=MAX_LENGTH, truncation=True, stride=doc_stride, padding=True)
valid_encodings_fast = tokenizerFast(valid_questions, valid_contexts, max_length=MAX_LENGTH, truncation=True, stride=doc_stride, padding=True)

def ret_Answer_start_and_end(idx, answers, encodings):
    """Locate the token span of an answer inside an encoded example.

    The text of ``answers[idx]`` is tokenized on its own. Its inner tokens
    (everything except the first and last id, presumably the special tokens
    the tokenizer adds) are then searched for, left to right, in
    ``encodings['input_ids'][idx]``.

    Args:
        idx: Index of the example, used for both ``answers`` and ``encodings``.
        answers: Sequence of dicts with a ``'text'`` entry.
        encodings: Mapping whose ``'input_ids'`` entry holds one token-id
            sequence per example.

    Returns:
        ``(start, end)`` for the first match, where ``start`` is the index of
        the first answer token and ``end`` is one past the last answer token.
        ``(0, 0)`` when the answer tokens are not found.
    """
    answer_ids = tokenizerFast(answers[idx]['text'], max_length=MAX_LENGTH, truncation=True, padding=True)['input_ids']
    sequence = encodings['input_ids'][idx]
    width = len(answer_ids)

    for offset in range(len(sequence) - width):
        # Compare only the inner answer tokens against the window at `offset`.
        if all(answer_ids[k] == sequence[offset + k] for k in range(1, width - 1)):
            return offset + 1, offset + width - 1

    # No window matched, so fall back to position 0 for both ends.
    return 0, 0

# Label every encoded training example with the token span of its answer.
train_spans = [
    ret_Answer_start_and_end(h, train_answers, train_encodings_fast)
    for h in range(len(train_encodings_fast['input_ids']))
]
train_start_positions = [span[0] for span in train_spans]
train_end_positions = [span[1] for span in train_spans]

train_encodings_fast.update({'start_positions': train_start_positions, 'end_positions': train_end_positions})

# Same labelling for the validation examples.
valid_spans = [
    ret_Answer_start_and_end(h, valid_answers, valid_encodings_fast)
    for h in range(len(valid_encodings_fast['input_ids']))
]
valid_start_positions = [span[0] for span in valid_spans]
valid_end_positions = [span[1] for span in valid_spans]

valid_encodings_fast.update({'start_positions': valid_start_positions, 'end_positions': valid_end_positions})

class InputDataset(Dataset):
    """Map-style dataset that serves tokenized QA examples as tensors.

    ``encodings`` must support ``encodings[field][i]`` for each of the five
    fields listed in ``_FIELDS``.
    """

    # Fields copied out of the encodings for every example, in output order.
    _FIELDS = (
        'input_ids',
        'token_type_ids',
        'attention_mask',
        'start_positions',
        'end_positions',
    )

    def __init__(self, encodings):
        self.encodings = encodings

    def __len__(self):
        """Return the number of examples, i.e. the number of ``input_ids`` rows."""
        return len(self.encodings['input_ids'])

    def __getitem__(self, i):
        """Return example ``i`` as a dict mapping each field name to a tensor."""
        return {name: torch.tensor(self.encodings[name][i]) for name in self._FIELDS}

# Wrap the labelled encodings so a DataLoader can index them example by example.
train_dataset = InputDataset(train_encodings_fast)
valid_dataset = InputDataset(valid_encodings_fast)

# Training batches are shuffled; validation uses batches of one example.
train_data_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
valid_data_loader = DataLoader(valid_dataset, batch_size=1)

# Load the pretrained ALBERT encoder named by MODEL_PATH.
albert_model = AlbertModel.from_pretrained(MODEL_PATH)

class QAModel(nn.Module):
    """Span-extraction head on top of the module-level ALBERT encoder.

    Each token is represented by the last hidden-state layer concatenated
    with the third-from-last one. A small MLP (dropout, linear, LeakyReLU,
    linear) maps that representation to two scores per token: a start logit
    and an end logit.
    """

    def __init__(self):
        super().__init__()
        self.albert = albert_model
        # Two hidden-state layers are concatenated, so the head input is twice
        # the encoder width. It is read from the encoder config rather than
        # hard-coded, so checkpoints with a different hidden size also work.
        feature_dim = self.albert.config.hidden_size * 2
        self.drop_out = nn.Dropout(0.1)
        self.l1 = nn.Linear(feature_dim, feature_dim)
        self.l2 = nn.Linear(feature_dim, 2)
        self.linear_relu_stack = nn.Sequential(
            self.drop_out,
            self.l1,
            nn.LeakyReLU(),
            self.l2,
        )

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Return ``(start_logits, end_logits)``, each with the last axis squeezed away.

        The three arguments are forwarded unchanged to the encoder.
        """
        model_output = self.albert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            output_hidden_states=True,
        )
        # Index 2 of the encoder output is taken as the per-layer hidden states.
        hidden_states = model_output[2]
        out = torch.cat((hidden_states[-1], hidden_states[-3]), dim=-1)
        logits = self.linear_relu_stack(out)

        # Split the two output channels into separate start and end scores.
        start_logits, end_logits = logits.split(1, dim=-1)
        return start_logits.squeeze(-1), end_logits.squeeze(-1)

# Build the QA network; it wraps the module-level `albert_model` encoder.
model = QAModel()

def focal_loss_fn(start_logits, end_logits, start_positions, end_positions, gamma):
    """Return the focal loss averaged over the start and end predictions.

    For each side the log-probability along dim 1 is scaled by
    ``(1 - p) ** gamma`` and passed to a mean-reduced negative log-likelihood
    against the target indices. The result is the mean of the two sides.

    Args:
        start_logits: Start scores; softmax is taken over dim 1.
        end_logits: End scores; softmax is taken over dim 1.
        start_positions: Target start indices.
        end_positions: Target end indices.
        gamma: Focusing exponent; ``0`` reduces to plain cross-entropy.
    """
    def _focal_term(logits, targets):
        # Down-weight well-classified positions before taking the NLL.
        probs = torch.softmax(logits, dim=1)
        log_probs = torch.log_softmax(logits, dim=1)
        modulated = torch.pow(1 - probs, gamma) * log_probs
        return nn.functional.nll_loss(modulated, targets)

    start_term = _focal_term(start_logits, start_positions)
    end_term = _focal_term(end_logits, end_positions)
    return (start_term + end_term) / 2

# AdamW over every parameter of `model`, which includes the wrapped encoder as well as the head.
optim = AdamW(model.parameters(), lr=2e-5, weight_decay=2e-2)
# NOTE(review): these two lists are not referenced in the visible code — confirm whether they are still needed.
total_acc = []
total_loss = []

def train_epoch(model, dataloader, epoch):
    """Run one optimisation pass over ``dataloader`` and report averages.

    Args:
        model: Callable taking ``input_ids``, ``attention_mask`` and
            ``token_type_ids`` keywords and returning
            ``(start_logits, end_logits)``.
        dataloader: Iterable of dict batches with those three keys plus
            ``start_positions`` and ``end_positions``.
        epoch: Accepted for the caller's convenience; not used here.

    Returns:
        ``(accuracy, loss)``. Accuracy is the mean of the per-batch start and
        end exact-position accuracies. Loss is the mean per-batch focal loss
        (gamma fixed at 1).
    """
    model.train()
    batch_losses = []
    batch_accs = []
    for batch in tqdm(dataloader, desc='Running Epoch '):
        optim.zero_grad()
        features = {
            key: batch[key].to(device)
            for key in ('input_ids', 'attention_mask', 'token_type_ids')
        }
        gold_start = batch['start_positions'].to(device)
        gold_end = batch['end_positions'].to(device)

        out_start, out_end = model(**features)

        loss = focal_loss_fn(out_start, out_end, gold_start, gold_end, 1)
        batch_losses.append(loss.item())
        loss.backward()
        optim.step()

        # Start accuracy is recorded first, then end accuracy, for each batch.
        for logits, gold in ((out_start, gold_start), (out_end, gold_end)):
            predicted = torch.argmax(logits, dim=1)
            batch_accs.append(((predicted == gold).sum() / len(predicted)).item())

    mean_acc = sum(batch_accs) / len(batch_accs)
    mean_loss = sum(batch_losses) / len(batch_losses)
    return (mean_acc, mean_loss)

def eval_model(model, dataloader):
    """Decode predicted and reference answer strings for every example.

    The model is put in eval mode and run without gradient tracking. For each
    example the predicted span is ``argmax(start_logits)`` to
    ``argmax(end_logits)`` (end exclusive) and the reference span is
    ``start_positions`` to ``end_positions``. Both are sliced out of
    ``input_ids`` and converted back to text with the module-level tokenizer.

    The argmax is taken per example along dim 1, so batches of any size are
    handled. For a batch size of 1 the output is the same as a flattened
    argmax over the whole batch.

    Args:
        model: Callable returning ``(start_logits, end_logits)``; the logits
            are assumed to be ``(batch, seq_len)``.
        dataloader: Iterable of dict batches with ``input_ids``,
            ``attention_mask``, ``token_type_ids``, ``start_positions`` and
            ``end_positions``.

    Returns:
        A list with one ``[predicted_text, reference_text]`` pair per example,
        in dataloader order. A predicted span whose end does not exceed its
        start decodes from an empty slice.
    """
    def _decode(token_ids):
        # Token ids -> tokens -> detokenized string.
        return tokenizerFast.convert_tokens_to_string(tokenizerFast.convert_ids_to_tokens(token_ids))

    model.eval()
    answer_list = []
    with torch.no_grad():
        for batch in tqdm(dataloader, desc='Running Evaluation'):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            token_type_ids = batch['token_type_ids'].to(device)

            out_start, out_end = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)

            # Per-example argmax; a global argmax would mix positions across rows.
            start_pred = torch.argmax(out_start, dim=1).tolist()
            end_pred = torch.argmax(out_end, dim=1).tolist()
            start_true = batch['start_positions'].tolist()
            end_true = batch['end_positions'].tolist()

            for row, s_hat, e_hat, s_ref, e_ref in zip(input_ids, start_pred, end_pred, start_true, end_true):
                answer_list.append([_decode(row[s_hat:e_hat]), _decode(row[s_ref:e_ref])])

    return answer_list

# Number of full passes over the training data.
EPOCHS = 4
model.to(device)
# Per-epoch validation metrics, appended in epoch order.
wer_list = []
f1_scores = []

for epoch in range(EPOCHS):
    train_acc, train_loss = train_epoch(model, train_data_loader, epoch + 1)
    print(f'Epoch - {epoch}')
    print(f"Accuracy: {train_acc}")
    print(f"Loss: {train_loss}")

    answer_list = eval_model(model, valid_data_loader)
    # Empty strings are swapped for a "$" placeholder before scoring.
    for pair in answer_list:
        for slot in (0, 1):
            if not pair[slot]:
                pair[slot] = "$"
    pred_answers = [pair[0] for pair in answer_list]
    true_answers = [pair[1] for pair in answer_list]

    wer_score = wer(true_answers, pred_answers)
    wer_list.append(wer_score)
    # NOTE(review): sklearn's f1_score treats each distinct answer string as a
    # class label, so this is an exact-match classification F1 rather than the
    # token-overlap F1 usually reported for SQuAD — confirm this is intended.
    f1 = f1_score(true_answers, pred_answers, average='weighted')
    f1_scores.append(f1)
    print(f"F1 Score: {f1}")
print('WER - ', wer_list)
print('F1 Scores (per epoch)- ', f1_scores)

  torch.utils._pytree._register_pytree_node(
Running Epoch : 100%|██████████| 9278/9278 [22:05<00:00,  7.00it/s]


Epoch - 0
Accuracy: 0.5989706833385321
Loss: 1.2075854284208212


Running Evaluation: 100%|██████████| 15875/15875 [04:01<00:00, 65.79it/s]


F1 Score: 0.4519987504819997


Running Epoch : 100%|██████████| 9278/9278 [21:47<00:00,  7.10it/s]


Epoch - 1
Accuracy: 0.7044621685708127
Loss: 0.7604748443000727


Running Evaluation: 100%|██████████| 15875/15875 [03:58<00:00, 66.55it/s]


F1 Score: 0.46626100444024215


Running Epoch : 100%|██████████| 9278/9278 [21:46<00:00,  7.10it/s]


Epoch - 2
Accuracy: 0.7672855141194223
Loss: 0.5331318821653838


Running Evaluation: 100%|██████████| 15875/15875 [03:56<00:00, 67.01it/s]


F1 Score: 0.4745086637918573


Running Epoch : 100%|██████████| 9278/9278 [21:45<00:00,  7.11it/s]


Epoch - 3
Accuracy: 0.8115569088181613
Loss: 0.3948216319307098


Running Evaluation: 100%|██████████| 15875/15875 [03:57<00:00, 66.81it/s]


F1 Score: 0.4561390520508842
WER -  [0.8783082888996607, 0.7918565196316044, 0.8984730974309258, 0.9077556955889481]
F1 Scores (per epoch)-  [0.4519987504819997, 0.46626100444024215, 0.4745086637918573, 0.4561390520508842]
