In [30]:
import requests
import json
import torch
import os
from tqdm import tqdm
import transformers

In [2]:
with open('spoken_train-v1.1.json', 'rb') as f:
  squad = json.load(f)

In [3]:
squad['data'][0].keys()


dict_keys(['title', 'paragraphs'])

In [4]:
def read_data(path):  
  # load the json file
  with open(path, 'rb') as f:
    squad = json.load(f)

  contexts = []
  questions = []
  answers = []

  for group in squad['data']:
    for passage in group['paragraphs']:
      context = passage['context']
      for qa in passage['qas']:
        question = qa['question']
        for answer in qa['answers']:
          contexts.append(context)
          questions.append(question)
          answers.append(answer)

  return contexts, questions, answers

In [5]:
train_contexts, train_questions, train_answers = read_data('spoken_train-v1.1.json')
valid_contexts, valid_questions, valid_answers = read_data('spoken_test-v1.1.json')
     


In [59]:
print(train_questions[-14000])
print(train_answers[-14000])

What did Hermann Ebbinghaus study?
{'answer_start': 406, 'text': 'memory studies', 'answer_end': 420}


In [49]:
def add_end_idx(answers, contexts):
  for answer, context in zip(answers, contexts):
    gold_text = answer['text']
    start_idx = answer['answer_start']
    end_idx = start_idx + len(gold_text)

    # sometimes squad answers are off by a character or two so we fix this
    if context[start_idx:end_idx] == gold_text:
      answer['answer_end'] = end_idx
    elif context[start_idx-1:end_idx-1] == gold_text:
      answer['answer_start'] = start_idx - 1
      answer['answer_end'] = end_idx - 1     # When the gold label is off by one character
    elif context[start_idx-2:end_idx-2] == gold_text:
      answer['answer_start'] = start_idx - 2
      answer['answer_end'] = end_idx - 2     # When the gold label is off by two characters

add_end_idx(train_answers, train_contexts)
add_end_idx(valid_answers, valid_contexts)

In [50]:
print(train_questions[-14000])
print(train_answers[-14000])

What did Hermann Ebbinghaus study?
{'answer_start': 406, 'text': 'memory studies', 'answer_end': 420}


In [16]:
# from transformers import BertTokenizerFast

# tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

# train_encodings = tokenizer(train_contexts, train_questions, truncation=True, padding=True)
# valid_encodings = tokenizer(valid_contexts, valid_questions, truncation=True, padding=True)
     

In [51]:
from transformers import BertTokenizerFast
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
doc_stride = 128  # Set the doc stride value as per your requirements
train_encodings = tokenizer(
    train_contexts,
    train_questions,
    truncation=True,
    padding=True,
    max_length=250,
    stride=doc_stride
)
valid_encodings = tokenizer(
    valid_contexts,
    valid_questions,
    truncation=True,
    padding=True,
    max_length=250,
    stride=doc_stride
)

In [10]:
train_encodings.keys()


dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])

In [11]:
no_of_encodings = len(train_encodings['input_ids'])
no_of_valid_encodings = len(valid_encodings['input_ids'])
print(f'There are {no_of_encodings} context-question pairs in training dataset')
print(f'There are {no_of_valid_encodings} context-question pairs in testing dataset')   

There are 37111 context-question pairs in training dataset
There are 15875 context-question pairs in testing dataset


In [60]:

# Accessing the first record in train_encodings_fast
encoding = valid_encodings[0]

# Accessing the attributes of the encoding
ids = encoding.ids
tokens = encoding.tokens
offsets = encoding.offsets
attention_mask = encoding.attention_mask
special_tokens_mask = encoding.special_tokens_mask

# Printing the results
print("IDs: ", ids)
print("Tokens: ", tokens)
print("Offsets: ", offsets)
print("Attention Mask: ", attention_mask)
print("Special Tokens Mask: ", special_tokens_mask)


IDs:  [101, 3565, 4605, 5595, 2001, 2019, 2137, 2374, 2208, 2000, 5646, 1996, 3410, 1997, 1996, 2120, 2374, 2223, 5088, 2005, 1996, 3174, 5417, 2161, 1012, 1996, 2137, 2374, 3034, 1037, 1042, 1039, 1039, 3410, 7573, 14169, 3249, 1996, 2120, 2374, 3034, 1050, 1042, 1039, 1039, 3410, 3792, 12915, 3174, 2176, 2000, 2702, 2000, 7796, 2037, 2353, 3565, 4605, 2516, 1012, 1996, 2208, 2001, 2209, 2006, 2337, 5066, 3174, 7032, 1998, 11902, 2015, 3346, 1999, 1996, 2624, 3799, 3016, 2181, 4203, 10254, 2662, 1012, 2004, 2023, 2001, 1996, 10882, 6199, 2666, 2705, 3565, 4605, 1996, 2223, 13155, 1996, 3585, 5315, 2007, 2536, 25507, 2015, 11107, 2004, 2092, 2004, 8184, 28324, 2075, 1996, 4535, 1997, 10324, 2169, 3565, 4605, 2208, 2007, 3142, 16371, 28990, 2015, 2104, 2029, 2027, 5114, 2052, 2031, 2042, 2124, 2004, 3565, 4605, 1048, 5271, 2008, 1996, 8154, 2071, 14500, 2956, 1996, 5640, 16371, 28990, 2015, 5595, 1012, 102, 2029, 5088, 2136, 3421, 1996, 10511, 2012, 3565, 4605, 2753, 1029, 102, 0, 0, 0,

In [52]:
def add_token_positions(encodings, answers):
    start_positions = []
    end_positions = []
    for i, answer in enumerate(answers):
        start_pos = max(0, answer['answer_start'])
        end_pos = max(0, answer['answer_end'] - 1)
        start_positions.append(encodings.char_to_token(i, start_pos))
        end_positions.append(encodings.char_to_token(i, end_pos))

        # if start position is None, the answer passage has been truncated
        if start_positions[-1] is None:
            start_positions[-1] = tokenizer.model_max_length
        if end_positions[-1] is None:
            end_positions[-1] = tokenizer.model_max_length

    encodings.update({'start_positions': start_positions, 'end_positions': end_positions})

# add_token_positions(train_encodings, train_answers)
add_token_positions(valid_encodings, valid_answers)


In [61]:
valid_encodings['start_positions'][:10]


[34, 34, 34, 46, 46, 46, 79, 70, 34, 34]

In [62]:
class SQuAD_Dataset(torch.utils.data.Dataset):
  def __init__(self, encodings):
    self.encodings = encodings
  def __getitem__(self, idx):
    return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
  def __len__(self):
    return len(self.encodings.input_ids)
     

In [63]:

train_dataset = SQuAD_Dataset(train_encodings)
valid_dataset = SQuAD_Dataset(valid_encodings)
 


In [64]:
from torch.utils.data import DataLoader

# Define the dataloaders
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=16)

    

In [21]:
from transformers import BertForQuestionAnswering

model = BertForQuestionAnswering.from_pretrained("bert-base-uncased")

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForQuestionAnswering: ['cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-base-uncased a

In [65]:
valid_contexts, valid_questions, valid_answers = read_data('spoken_test-v1.1.json')


In [23]:

# Check on the available device - use GPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Working on {device}')
     

Working on cuda


In [47]:
# torch.cuda.empty_cache()

In [24]:
from transformers import AdamW

N_EPOCHS = 4
optim = AdamW(model.parameters(), lr=5e-5)

model.to(device)
model.train()

for epoch in range(N_EPOCHS):
  loop = tqdm(train_loader, leave=True)
  for batch in loop:
    optim.zero_grad()
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    start_positions = batch['start_positions'].to(device)
    end_positions = batch['end_positions'].to(device)
    outputs = model(input_ids, attention_mask=attention_mask, start_positions=start_positions, end_positions=end_positions)
    loss = outputs[0]
    loss.backward()
    optim.step()

    loop.set_description(f'Epoch {epoch+1}')
    loop.set_postfix(loss=loss.item())

Epoch 1:  81%|████████  | 1874/2320 [05:22<01:16,  5.81it/s, loss=1.34] IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

Epoch 1: 100%|██████████| 2320/2320 [06:39<00:00,  5.81it/s, loss=0.979]
Epoch 2: 100%|██████████| 2320/2320 [06:42<00:00,  5.77it/s, loss=1.49] 
Epoch 3: 100%|██████████| 2320/2320 [06:41<00:00,  5.78it/s, loss=0.402] 
Epoch 4: 100%|██████████| 2320/2320 [06:40<00:00,  5.79it/s, loss=0.481] 


In [41]:
loop = tqdm(train_loader, leave=True)
for batch in loop:
    optim.zero_grad()
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    start_positions = batch['start_positions'].to(device)
    end_positions = batch['end_positions'].to(device)

100%|██████████| 2320/2320 [00:04<00:00, 521.53it/s]


In [42]:
start_positions

tensor([ 32, 125,  64,  30,  46,   6, 155], device='cuda:0')

In [52]:
def evaluate(model, dataloader, device):
    model.eval()
    total_loss = 0
    all_predictions = []
    all_answers = []

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            start_positions = batch['start_positions'].to(device)
            end_positions = batch['end_positions'].to(device)

            outputs = model(input_ids, attention_mask=attention_mask, start_positions=start_positions, end_positions=end_positions)
            loss = outputs[0]
            total_loss += loss.item()

            start_logits, end_logits = outputs[1], outputs[2]
            for i in range(len(start_logits)):
                start_pred = torch.argmax(start_logits[i]).item()
                end_pred = torch.argmax(end_logits[i]).item()
                prediction = tokenizer.decode(input_ids[i][start_pred:end_pred+1])
                all_predictions.append(prediction)
                all_answers.append(tokenizer.decode(input_ids[i][start_positions[i]:end_positions[i]+1]))

    f1 = compute_f1(all_predictions, all_answers)

    return total_loss/len(dataloader), f1


In [39]:
model_path = 'DL'
model.save_pretrained(model_path)
tokenizer.save_pretrained(model_path)

('DL/tokenizer_config.json',
 'DL/special_tokens_map.json',
 'DL/vocab.txt',
 'DL/added_tokens.json',
 'DL/tokenizer.json')

In [67]:
# model.eval()

# acc = []

# for batch in tqdm(valid_loader):
#   with torch.no_grad():
#     input_ids = batch['input_ids'].to(device)
#     attention_mask = batch['attention_mask'].to(device)
#     start_true = batch['start_positions'].to(device)
#     end_true = batch['end_positions'].to(device)
    
#     outputs = model(input_ids, attention_mask=attention_mask)

#     start_pred = torch.argmax(outputs['start_logits'], dim=1)
#     end_pred = torch.argmax(outputs['end_logits'], dim=1)

#     acc.append(((start_pred == start_true).sum()/len(start_pred)).item())
#     acc.append(((end_pred == end_true).sum()/len(end_pred)).item())

# acc = sum(acc)/len(acc)

# print("\n\nT/P\tanswer_start\tanswer_end\n")
# for i in range(len(start_true)):
#   print(f"true\t{start_true[i]}\t{end_true[i]}\n"
#         f"pred\t{start_pred[i]}\t{end_pred[i]}\n")
     

In [73]:
def get_prediction(context, question):
  inputs = tokenizer.encode_plus(question, context, return_tensors='pt').to(device)
  outputs = model(**inputs)
  
  answer_start = torch.argmax(outputs[0])  
  answer_end = torch.argmax(outputs[1]) + 1 
  
  answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(inputs['input_ids'][0][answer_start:answer_end]))
  
  return answer

def normalize_text(s):
  """Removing articles and punctuation, and standardizing whitespace are all typical text processing steps."""
  import string, re
  def remove_articles(text):
    regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
    return re.sub(regex, " ", text)
  def white_space_fix(text):
    return " ".join(text.split())
  def remove_punc(text):
    exclude = set(string.punctuation)
    return "".join(ch for ch in text if ch not in exclude)
  def lower(text):
    return text.lower()

  return white_space_fix(remove_articles(remove_punc(lower(s))))

def exact_match(prediction, truth):
    return bool(normalize_text(prediction) == normalize_text(truth))

def compute_f1(prediction, truth):
  pred_tokens = normalize_text(prediction).split()
  truth_tokens = normalize_text(truth).split()
  
  # if either the prediction or the truth is no-answer then f1 = 1 if they agree, 0 otherwise
  if len(pred_tokens) == 0 or len(truth_tokens) == 0:
    return int(pred_tokens == truth_tokens)
  
  common_tokens = set(pred_tokens) & set(truth_tokens)
  
  # if there are no common tokens then f1 = 0
  if len(common_tokens) == 0:
    return 0
  
  prec = len(common_tokens) / len(pred_tokens)
  rec = len(common_tokens) / len(truth_tokens)
  
  return round(2 * (prec * rec) / (prec + rec), 2)

  
def question_answer(context, question,answer):
  prediction = get_prediction(context,question)
  em_score = exact_match(prediction, answer)
  f1_score = compute_f1(prediction, answer)

  print(f'Question: {question}')
  print(f'Prediction: {prediction}')
  print(f'True Answer: {answer}')
  print(f'Exact match: {em_score}')
  print(f'F1 score: {f1_score}\n')

In [28]:
# valid_contexts44, valid_questions44, valid_answers44 = read_data('spoken_test-v1.1_WER44.json')
    
# add_end_idx(valid_answers44, valid_contexts44)
# valid_encodings44 = tokenizer(
#     valid_contexts44,
#     valid_questions44,
#     truncation=True,
#     padding=True,
#     max_length=250,
#     stride=doc_stride)
# add_token_positions(valid_encodings44, valid_answers44)
# valid_dataset44 = SQuAD_Dataset(valid_encodings44)
# valid_loader44 = DataLoader(valid_dataset44, batch_size=8)

In [70]:
import jiwer

model.eval()
acc = []
wer = []

for batch in tqdm(valid_loader):
    with torch.no_grad():
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        start_true = batch['start_positions'].to(device)
        end_true = batch['end_positions'].to(device)
        
        outputs = model(input_ids, attention_mask=attention_mask)
        
        start_pred = torch.argmax(outputs['start_logits'], dim=1)
        end_pred = torch.argmax(outputs['end_logits'], dim=1)
        
        # Calculate accuracy
        acc.append(((start_pred == start_true).sum() / len(start_pred)).item())
        acc.append(((end_pred == end_true).sum() / len(end_pred)).item())
        
        # Calculate WER
        for i in range(len(start_true)):
            true_text = tokenizer.decode(input_ids[i][start_true[i]:end_true[i]])
            pred_text = tokenizer.decode(input_ids[i][start_pred[i]:end_pred[i]])
            if true_text.strip() == "":
                continue
            wer.append(jiwer.wer(true_text, pred_text))
            f1_scores.append(compute_f1(true_text, pred_text))
        
acc = sum(acc) / len(acc)
wer = sum(wer) / len(wer)

print(f'Accuracy: {acc:.4f}')
print(f'WER: {wer:.4f}')


100%|██████████| 993/993 [00:55<00:00, 17.90it/s]

Accuracy: 0.5763
WER: 1.9152





In [72]:
overall_f1_score = np.mean(f1_scores)
print(f'F1 Score: {overall_f1_score:.4f}')

F1 Score: 0.5643


In [None]:
# Gradient Accumulation

In [69]:
from transformers import AdamW

N_EPOCHS = 4
GRAD_ACCUM_STEPS = 4
optim = AdamW(model.parameters(), lr=5e-5)

model.to(device)
model.train()

for epoch in range(N_EPOCHS):
  loop = tqdm(train_loader, leave=True)
  epoch_loss = 0
  for i, batch in enumerate(loop):
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    start_positions = batch['start_positions'].to(device)
    end_positions = batch['end_positions'].to(device)
    outputs = model(input_ids, attention_mask=attention_mask, start_positions=start_positions, end_positions=end_positions)
    loss = outputs[0]
    loss /= GRAD_ACCUM_STEPS
    loss.backward()
    epoch_loss += loss.item()
    if (i + 1) % GRAD_ACCUM_STEPS == 0:
      optim.step()
      optim.zero_grad()

    loop.set_description(f'Epoch {epoch+1}')
    loop.set_postfix(loss=epoch_loss / (i+1))



Epoch 1: 100%|██████████| 2320/2320 [06:22<00:00,  6.07it/s, loss=0.0107] 
Epoch 2: 100%|██████████| 2320/2320 [06:23<00:00,  6.04it/s, loss=0.00976]
Epoch 3: 100%|██████████| 2320/2320 [06:23<00:00,  6.04it/s, loss=0.0103] 
Epoch 4: 100%|██████████| 2320/2320 [06:23<00:00,  6.04it/s, loss=0.0096] 


In [70]:
model.eval()

total_loss = 0
acc = []

for batch in tqdm(valid_loader):
  with torch.no_grad():
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    start_true = batch['start_positions'].to(device)
    end_true = batch['end_positions'].to(device)
    
    outputs = model(input_ids, attention_mask=attention_mask, start_positions=start_true, end_positions=end_true)

    loss = outputs[0]
    total_loss += loss.item()

    start_pred = torch.argmax(outputs['start_logits'], dim=1)
    end_pred = torch.argmax(outputs['end_logits'], dim=1)

    acc.append(((start_pred == start_true).sum()/len(start_pred)).item())
    acc.append(((end_pred == end_true).sum()/len(end_pred)).item())

acc = sum(acc)/len(acc)
avg_loss = total_loss / len(valid_loader)

print(f"\nValidation Loss: {avg_loss:.4f}")
print(f"Validation Accuracy: {acc:.4f}")
print("\n\nT/P\tanswer_start\tanswer_end\n")
for i in range(len(start_true)):
  print(f"true\t{start_true[i]}\t{end_true[i]}\n"
        f"pred\t{start_pred[i]}\t{end_pred[i]}\n")


100%|██████████| 993/993 [00:53<00:00, 18.56it/s]


Validation Loss: 3.3698
Validation Accuracy: 0.5640


T/P	answer_start	answer_end

true	59	59
pred	54	54

true	59	60
pred	54	54

true	59	59
pred	54	54






In [None]:
# model.eval()

# acc = []

# for batch in tqdm(valid_loader):
#   with torch.no_grad():
#     input_ids = batch['input_ids'].to(device)
#     attention_mask = batch['attention_mask'].to(device)
#     start_true = batch['start_positions'].to(device)
#     end_true = batch['end_positions'].to(device)
    
#     outputs = model(input_ids, attention_mask=attention_mask)

#     start_pred = torch.argmax(outputs['start_logits'], dim=1)
#     end_pred = torch.argmax(outputs['end_logits'], dim=1)

#     acc.append(((start_pred == start_true).sum()/len(start_pred)).item())
#     acc.append(((end_pred == end_true).sum()/len(end_pred)).item())

# acc = sum(acc)/len(acc)

# print("\n\nT/P\tanswer_start\tanswer_end\n")
# for i in range(len(start_true)):
#   print(f"true\t{start_true[i]}\t{end_true[i]}\n"
#         f"pred\t{start_pred[i]}\t{end_pred[i]}\n")
     

In [71]:
# def get_prediction(context, question):
#   inputs = tokenizer.encode_plus(question, context, return_tensors='pt').to(device)
#   outputs = model(**inputs)
  
#   answer_start = torch.argmax(outputs[0])  
#   answer_end = torch.argmax(outputs[1]) + 1 
  
#   answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(inputs['input_ids'][0][answer_start:answer_end]))
  
#   return answer

# def normalize_text(s):
#   """Removing articles and punctuation, and standardizing whitespace are all typical text processing steps."""
#   import string, re
#   def remove_articles(text):
#     regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
#     return re.sub(regex, " ", text)
#   def white_space_fix(text):
#     return " ".join(text.split())
#   def remove_punc(text):
#     exclude = set(string.punctuation)
#     return "".join(ch for ch in text if ch not in exclude)
#   def lower(text):
#     return text.lower()

#   return white_space_fix(remove_articles(remove_punc(lower(s))))

# def exact_match(prediction, truth):
#     return bool(normalize_text(prediction) == normalize_text(truth))

# def compute_f1(prediction, truth):
#   pred_tokens = normalize_text(prediction).split()
#   truth_tokens = normalize_text(truth).split()
  
#   # if either the prediction or the truth is no-answer then f1 = 1 if they agree, 0 otherwise
#   if len(pred_tokens) == 0 or len(truth_tokens) == 0:
#     return int(pred_tokens == truth_tokens)
  
#   common_tokens = set(pred_tokens) & set(truth_tokens)
  
#   # if there are no common tokens then f1 = 0
#   if len(common_tokens) == 0:
#     return 0
  
#   prec = len(common_tokens) / len(pred_tokens)
#   rec = len(common_tokens) / len(truth_tokens)
  
#   return round(2 * (prec * rec) / (prec + rec), 2)

  
# def question_answer(context, question,answer):
#   prediction = get_prediction(context,question)
#   em_score = exact_match(prediction, answer)
#   f1_score = compute_f1(prediction, answer)

#   print(f'Question: {question}')
#   print(f'Prediction: {prediction}')
#   print(f'True Answer: {answer}')
#   print(f'Exact match: {em_score}')
#   print(f'F1 score: {f1_score}\n')

In [72]:
import jiwer

model.eval()
wer = []
f1_scores=[]
for batch in tqdm(valid_loader):
    with torch.no_grad():
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        start_true = batch['start_positions'].to(device)
        end_true = batch['end_positions'].to(device)
        
        outputs = model(input_ids, attention_mask=attention_mask)
        
        start_pred = torch.argmax(outputs['start_logits'], dim=1)
        end_pred = torch.argmax(outputs['end_logits'], dim=1)
        
        # Calculate WER
        for i in range(len(start_true)):
            true_text = tokenizer.decode(input_ids[i][start_true[i]:end_true[i]])
            pred_text = tokenizer.decode(input_ids[i][start_pred[i]:end_pred[i]])
            if true_text.strip() == "":
                continue
            wer.append(jiwer.wer(true_text, pred_text))
            f1_scores.append(compute_f1(true_text, pred_text))
            
        
wer = sum(wer) / len(wer)
print(f'WER: {wer:.4f}')


100%|██████████| 993/993 [00:55<00:00, 17.92it/s]

WER: 1.1681





In [81]:
import numpy as np
overall_f1_score = np.mean(f1_scores)
print(overall_f1_score)

0.5982779567418316


In [82]:
from transformers import AdamW, get_linear_schedule_with_warmup

N_EPOCHS = 4
GRAD_ACCUM_STEPS = 4
MAX_GRAD_NORM = 1.0
WARMUP_STEPS = 500
T_TOTAL = len(train_loader) * N_EPOCHS // GRAD_ACCUM_STEPS

optim = AdamW(model.parameters(), lr=5e-5)

scheduler = get_linear_schedule_with_warmup(
    optim, num_warmup_steps=WARMUP_STEPS, num_training_steps=T_TOTAL
)

model.to(device)
model.train()

for epoch in range(N_EPOCHS):
    loop = tqdm(train_loader, leave=True)
    epoch_loss = 0
    
    for i, batch in enumerate(loop):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        start_positions = batch['start_positions'].to(device)
        end_positions = batch['end_positions'].to(device)
        
        model.zero_grad()
        outputs = model(input_ids, attention_mask=attention_mask, start_positions=start_positions, end_positions=end_positions)
        
        loss = outputs[0]
        loss /= GRAD_ACCUM_STEPS
        loss.backward()
        
        torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
        
        epoch_loss += loss.item()
        
        if (i + 1) % GRAD_ACCUM_STEPS == 0:
            optim.step()
            scheduler.step()
            optim.zero_grad()

        loop.set_description(f'Epoch {epoch+1}')
        loop.set_postfix(loss=epoch_loss / (i+1))


Epoch 1: 100%|██████████| 2320/2320 [06:35<00:00,  5.86it/s, loss=0.0088] 
Epoch 2: 100%|██████████| 2320/2320 [06:37<00:00,  5.84it/s, loss=0.0156]
Epoch 3: 100%|██████████| 2320/2320 [06:37<00:00,  5.84it/s, loss=0.0114]
Epoch 4: 100%|██████████| 2320/2320 [06:37<00:00,  5.84it/s, loss=0.00802]


In [84]:
import jiwer

model.eval()
acc = []
wer = []
f1_scores=[]
for batch in tqdm(valid_loader):
    with torch.no_grad():
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        start_true = batch['start_positions'].to(device)
        end_true = batch['end_positions'].to(device)
        
        outputs = model(input_ids, attention_mask=attention_mask)
        
        start_pred = torch.argmax(outputs['start_logits'], dim=1)
        end_pred = torch.argmax(outputs['end_logits'], dim=1)
        
        # Calculate accuracy
        acc.append(((start_pred == start_true).sum() / len(start_pred)).item())
        acc.append(((end_pred == end_true).sum() / len(end_pred)).item())
        
        # Calculate WER
        for i in range(len(start_true)):
            true_text = tokenizer.decode(input_ids[i][start_true[i]:end_true[i]])
            pred_text = tokenizer.decode(input_ids[i][start_pred[i]:end_pred[i]])
            if true_text.strip() == "":
                continue
            wer.append(jiwer.wer(true_text, pred_text))
            f1_scores.append(compute_f1(true_text, pred_text))
        
acc = sum(acc) / len(acc)
wer = sum(wer) / len(wer)
overall_f1_score = np.mean(f1_scores)
print(f'Accuracy: {acc:.6f}')
print(f'WER: {wer:.6f}')
print(f'F1_Score: {overall_f1_score:.6f}')

100%|██████████| 993/993 [00:55<00:00, 17.75it/s]

Accuracy: 0.565395
WER: 1.367479
F1_Score: 0.611829





In [85]:

print("\n\nT/P\tanswer_start\tanswer_end\n")
for i in range(len(start_true)):
    print(f"true\t{start_true[i]}\t\t\t{end_true[i]}\n"
          f"pred\t{start_pred[i]}\t\t\t{end_pred[i]}")



T/P	answer_start	answer_end

true	59			59
pred	59			54
true	59			60
pred	59			54
true	59			59
pred	59			54


In [86]:
from tabulate import tabulate

In [87]:

# Prepare data for table
data = []
for i in range(len(start_true)):
    data.append(['True', start_true[i], end_true[i]])
    data.append(['Pred', start_pred[i], end_pred[i]])

# Print table
print(tabulate(data, headers=['T/P', 'Answer Start', 'Answer End'], tablefmt='orgtbl'))

| T/P   |   Answer Start |   Answer End |
|-------+----------------+--------------|
| True  |             59 |           59 |
| Pred  |             59 |           54 |
| True  |             59 |           60 |
| Pred  |             59 |           54 |
| True  |             59 |           59 |
| Pred  |             59 |           54 |


In [56]:
import jiwer

model.eval()
acc = []
wer = []
f1_scores=[]
for batch in tqdm(valid_loader):
    with torch.no_grad():
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        start_true = batch['start_positions'].to(device)
        end_true = batch['end_positions'].to(device)
        
        outputs = model(input_ids, attention_mask=attention_mask)
        
        start_pred = torch.argmax(outputs['start_logits'], dim=1)
        end_pred = torch.argmax(outputs['end_logits'], dim=1)
        
        # Calculate accuracy
        acc.append(((start_pred == start_true).sum() / len(start_pred)).item())
        acc.append(((end_pred == end_true).sum() / len(end_pred)).item())
        
        # Calculate WER
        for i in range(len(start_true)):
            true_text = tokenizer.decode(input_ids[i][start_true[i]:end_true[i]])
            pred_text = tokenizer.decode(input_ids[i][start_pred[i]:end_pred[i]])
            if true_text.strip() == "":
                continue
            wer.append(jiwer.wer(true_text, pred_text))
            f1_scores.append(compute_f1(true_text, pred_text))
        
acc = sum(acc) / len(acc)
wer = sum(wer) / len(wer)
overall_f1_score = np.mean(f1_scores)
print(f'Accuracy: {acc:.6f}')
print(f'WER: {wer:.6f}')
print(f'F1_Score: {overall_f1_score:.6f}')

  0%|          | 0/993 [00:00<?, ?it/s]


KeyError: 'start_positions'