In [1]:
from datasets import load_from_disk, load_dataset, DatasetDict
from transformers import (
    TrainingArguments, Trainer, BatchEncoding,
    DistilBertTokenizerFast, DefaultDataCollator, DistilBertForQuestionAnswering
)
import torch
import numpy as np
import contextGenerator
import utils


checkpoint = 'distilbert-base-cased-distilled-squad'
tokenizer = DistilBertTokenizerFast.from_pretrained(checkpoint)
contextGen = contextGenerator.LuceneRetrieval()


try:
    ds = load_from_disk('../res/data/guess_train')
except:
    ds = load_dataset("community-datasets/qanta", "mode=first,char_skip=25")


  from .autonotebook import tqdm as notebook_tqdm
Mar 05, 2025 2:52:40 AM org.apache.lucene.store.MemorySegmentIndexInputProvider <init>
INFO: Using MemorySegmentIndexInput with Java 21; to disable start with -Dorg.apache.lucene.store.MMapDirectory.enableMemorySegments=false


# Preprocessing the data 
Given how BERT is a extractive model it will attempt to highlight its prediction in the provided context. In other words our task is to fine tune the model to predict the start and end positions of the answer in the context.  
#### 1. Retreive context
For each question we will need a relevent document where the answer may exist. 

In [2]:
try: 
    ds['test']['context']
except:
    ds = ds.map(lambda x: {'context':  contextGen(x['full_question'], 1)[0]})

#### 2. Find the start and end postions
The contexts and questions are just strings to so we need to find the positions for the answers in the context. 

In [3]:
try: 
    ds['test']['char_pos']
except:
    ds = ds.map(lambda x: {'char_pos':  utils.term_char_index(x['answer'], x['context']['contents'])})

### 2. Tokenize context/question pair and find the token positions
Ensure the context comes first in the pair to align the character index with the token index. BERT limits the combined token count of context and question to 512. Since the context is capped at 400 words, this won’t cause issues, but we’ll use padding and truncation for consistency and edge cases.

In [4]:
unpack = lambda x, y, z: {"start_positions": x, "end_positions": y, "encodings": z}

def tokenize_row(row: dict, tokenizer) -> dict[str, BatchEncoding]:
    try: 
        encoding =  tokenizer(
            text = row['context']['contents'], 
            text_pair = row['full_question'], 
            padding = 'max_length', 
            truncation = 'only_first', 
            max_length = 512, 
            return_tensors = 'pt', 
            padding_side = 'right',
            return_length = True
            )
    except:
        cleaned = utils.clean_text(row['full_question'])
        encoding =  tokenizer(
            text = row['context']['contents'], 
            text_pair = cleaned, 
            padding = 'max_length', 
            truncation = 'only_first', 
            max_length = 512, 
            return_tensors = 'pt', 
            padding_side = 'right',
            return_length = True
            )
    start_pos = []
    end_pos = []
    for (x, y) in row['char_pos']:
        st = encoding.char_to_token(x)
        ed = encoding.char_to_token(y - 1)
        if st is None or ed is None:
            st = -1
            ed = -1
        start_pos.append(st)
        end_pos.append(ed)
    if len(start_pos) == 0:
        start_pos.append(-1)
        end_pos.append(-1)
    encoding.update({'start_positions': start_pos, 'end_positions': end_pos})
    return {"encodings": encoding}

try: 
    ds['test']['encodings']
except:
    ds = ds.map(lambda x: tokenize_row(x, tokenizer))
    
    

In [5]:
def equ_len_pad(encs): 
    # I want to find the largest list of start positions, from this pa all of rest to meet this size
    longest_len = max([len(x['start_positions']) for x in encs])
    for x in encs: 
        x_len = len(x['start_positions'])
        x['start_positions'] = x['start_positions'] + ([-1 for x in range(longest_len - x_len)])
        x['end_positions'] = x['end_positions'] + ([-1 for x in range(longest_len - x_len)])
        x['input_ids'] =  x['input_ids'][0]
        x['attention_mask'] = x['attention_mask'][0]
    return encs

train = equ_len_pad(ds['train']['encodings'])
val = equ_len_pad(ds['val']['encodings'])
test = equ_len_pad(ds['test']['encodings'])


In [6]:
# guessTrain = DatasetDict({
#     'train': ds['guesstrain'],
#     'val': ds['guessdev'],
#     'test': ds['guesstest'],
# })

# train = ds['train']
# val = ds['val']
# test = ds['test']
# ds.save_to_disk('../res/data/guess_train')

# Training

In [7]:
class BartTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        # generate model's guess
        outputs = model(input_ids = inputs['input_ids'] , attention_mask = inputs['attention_mask'])
        # find the model's predictions 
        start_yhat = torch.argmax(outputs['start_logits'], dim= 1)
        end_yhat = torch.argmax(outputs['end_logits'], dim= 1)
        # check if this is a possible target
        # if it is set it as the target, else choose a random valid target 
        start_target = []
        end_target = []

        valid_start_targets = []
        valid_end_targets = []

        for idx, x in enumerate(inputs['start_positions']): 
            cur_start_targ = []
            cur_end_targ = []

            for idx2, y in enumerate(x): 
                if y != -1:
                    cur_start_targ.append(y)
                    cur_end_targ.append(inputs['end_positions'][idx][idx2])

                else: 
                    break 
                
                
            valid_start_targets.append(cur_start_targ if cur_start_targ else [-1])
            valid_end_targets.append(cur_end_targ if cur_end_targ else [-1])          
            
        for x in range(len(inputs['input_ids'])):
            if start_yhat[x] in valid_start_targets[x]:
                start_target.append(start_yhat[x])                
                end_target.append(end_yhat[x])
            else: 
                ran_int = np.random.randint(len(valid_start_targets[x]))
                s_rand = valid_start_targets[x][ran_int]
                e_rand = valid_end_targets[x][ran_int]
     
                start_target.append(s_rand)
                end_target.append(e_rand)

        
        device = outputs['start_logits'].device
        start_target = torch.tensor(start_target, dtype=torch.long, device=device)
        end_target = torch.tensor(end_target, dtype=torch.long, device=device)

        # Compute loss
        loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-1)
        start_loss = loss_fct(outputs['start_logits'], start_target)
        end_loss = loss_fct(outputs['end_logits'], end_target)
        
        total_loss = (start_loss + end_loss) / 2
        return (total_loss, outputs) if return_outputs else total_loss


In [None]:
tokenizer = DistilBertTokenizerFast.from_pretrained(checkpoint)
model = DistilBertForQuestionAnswering.from_pretrained(checkpoint)

if not torch.backends.mps.is_available():
    if not torch.backends.mps.is_built():
        print("MPS not available because the current PyTorch install was not "
              "built with MPS enabled.")
    else:
        print("MPS not available because the current MacOS version is not 12.3+ "
              "and/or you do not have an MPS-enabled device on this machine.")

else:
    mps_device = torch.device("mps")
    model.to(mps_device)


training_args = TrainingArguments(
    output_dir="res/models/guess_v1",
    learning_rate=5e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=5,
    weight_decay=0.01,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    push_to_hub=False
)

trainer = BartTrainer(
    model=model,
    args=training_args,                 
    train_dataset=train,
    eval_dataset=test,  
    data_collator=DefaultDataCollator(),  
    processing_class=tokenizer,    
)


trainer.train()

  2%|▏         | 500/30070 [09:01<8:44:21,  1.06s/it]

{'loss': 1.207, 'grad_norm': 2.1110000610351562, 'learning_rate': 4.9168606584635854e-05, 'epoch': 0.08}


  3%|▎         | 1000/30070 [17:57<8:36:13,  1.07s/it]

{'loss': 0.8761, 'grad_norm': 13.238821029663086, 'learning_rate': 4.83372131692717e-05, 'epoch': 0.17}


  5%|▍         | 1500/30070 [26:51<8:26:59,  1.06s/it]

{'loss': 0.7586, 'grad_norm': 24.95265769958496, 'learning_rate': 4.750581975390755e-05, 'epoch': 0.25}


  7%|▋         | 2000/30070 [35:46<8:17:22,  1.06s/it]

{'loss': 0.7675, 'grad_norm': 18.301746368408203, 'learning_rate': 4.66744263385434e-05, 'epoch': 0.33}


  8%|▊         | 2500/30070 [44:40<8:08:57,  1.06s/it]

{'loss': 0.7373, 'grad_norm': 12.275507926940918, 'learning_rate': 4.584303292317925e-05, 'epoch': 0.42}


 10%|▉         | 3000/30070 [53:35<8:03:39,  1.07s/it]

{'loss': 0.6985, 'grad_norm': 8.273239135742188, 'learning_rate': 4.50116395078151e-05, 'epoch': 0.5}


 12%|█▏        | 3500/30070 [1:02:30<7:41:40,  1.04s/it]

{'loss': 0.6947, 'grad_norm': 6.774771690368652, 'learning_rate': 4.418024609245095e-05, 'epoch': 0.58}


 13%|█▎        | 4000/30070 [1:11:15<7:40:45,  1.06s/it]

{'loss': 0.6681, 'grad_norm': 11.478196144104004, 'learning_rate': 4.33488526770868e-05, 'epoch': 0.67}


 15%|█▍        | 4500/30070 [1:20:12<7:41:16,  1.08s/it]

{'loss': 0.6419, 'grad_norm': 13.261557579040527, 'learning_rate': 4.2517459261722646e-05, 'epoch': 0.75}


 17%|█▋        | 5000/30070 [1:29:11<7:29:29,  1.08s/it]

{'loss': 0.6285, 'grad_norm': 3.608224868774414, 'learning_rate': 4.16860658463585e-05, 'epoch': 0.83}


 18%|█▊        | 5500/30070 [1:38:07<7:16:06,  1.06s/it]

{'loss': 0.6507, 'grad_norm': 7.499372482299805, 'learning_rate': 4.085467243099435e-05, 'epoch': 0.91}


 20%|█▉        | 6000/30070 [1:47:01<7:12:47,  1.08s/it]

{'loss': 0.6212, 'grad_norm': 10.14529037475586, 'learning_rate': 4.00232790156302e-05, 'epoch': 1.0}


                                                        
 20%|██        | 6014/30070 [1:48:05<7:09:07,  1.07s/it]

{'eval_loss': 0.8957862257957458, 'eval_runtime': 49.7146, 'eval_samples_per_second': 43.267, 'eval_steps_per_second': 2.716, 'epoch': 1.0}


 22%|██▏       | 6500/30070 [2:13:21<11:33:09,  1.76s/it] 

{'loss': 0.4573, 'grad_norm': 10.107312202453613, 'learning_rate': 3.919188560026605e-05, 'epoch': 1.08}


 23%|██▎       | 7000/30070 [2:29:28<12:05:37,  1.89s/it]

{'loss': 0.427, 'grad_norm': 0.7754099369049072, 'learning_rate': 3.8360492184901896e-05, 'epoch': 1.16}


 25%|██▍       | 7500/30070 [2:46:12<13:06:14,  2.09s/it]

{'loss': 0.4485, 'grad_norm': 0.15558142960071564, 'learning_rate': 3.752909876953775e-05, 'epoch': 1.25}


 27%|██▋       | 8000/30070 [3:02:25<11:06:01,  1.81s/it]

{'loss': 0.4627, 'grad_norm': 21.465978622436523, 'learning_rate': 3.669770535417359e-05, 'epoch': 1.33}


 28%|██▊       | 8500/30070 [3:20:39<11:34:59,  1.93s/it]

{'loss': 0.4625, 'grad_norm': 20.798330307006836, 'learning_rate': 3.5866311938809444e-05, 'epoch': 1.41}


 30%|██▉       | 9000/30070 [3:37:28<13:35:46,  2.32s/it]

{'loss': 0.4284, 'grad_norm': 7.976465702056885, 'learning_rate': 3.5034918523445296e-05, 'epoch': 1.5}


 32%|███▏      | 9500/30070 [3:54:25<12:08:10,  2.12s/it]

{'loss': 0.4313, 'grad_norm': 6.569569110870361, 'learning_rate': 3.420352510808115e-05, 'epoch': 1.58}


 33%|███▎      | 10000/30070 [4:11:12<11:12:33,  2.01s/it]

{'loss': 0.4302, 'grad_norm': 5.676962852478027, 'learning_rate': 3.3372131692717e-05, 'epoch': 1.66}


 35%|███▍      | 10500/30070 [4:27:35<11:10:21,  2.06s/it]

{'loss': 0.4427, 'grad_norm': 7.371875286102295, 'learning_rate': 3.254073827735284e-05, 'epoch': 1.75}


 37%|███▋      | 11000/30070 [4:44:27<11:00:30,  2.08s/it]

{'loss': 0.4151, 'grad_norm': 7.575404167175293, 'learning_rate': 3.1709344861988695e-05, 'epoch': 1.83}


 38%|███▊      | 11500/30070 [5:01:01<11:05:51,  2.15s/it]

{'loss': 0.4083, 'grad_norm': 5.982151031494141, 'learning_rate': 3.087795144662454e-05, 'epoch': 1.91}


 40%|███▉      | 12000/30070 [5:18:06<9:44:02,  1.94s/it] 

{'loss': 0.3801, 'grad_norm': 9.37171459197998, 'learning_rate': 3.004655803126039e-05, 'epoch': 2.0}


                                                          
 40%|████      | 12028/30070 [5:19:53<9:58:39,  1.99s/it]

{'eval_loss': 0.8897112011909485, 'eval_runtime': 49.037, 'eval_samples_per_second': 43.865, 'eval_steps_per_second': 2.753, 'epoch': 2.0}


 42%|████▏     | 12500/30070 [5:35:30<9:49:25,  2.01s/it] 

{'loss': 0.2844, 'grad_norm': 0.8601583242416382, 'learning_rate': 2.9215164615896246e-05, 'epoch': 2.08}


 43%|████▎     | 13000/30070 [5:52:10<8:41:55,  1.83s/it] 

{'loss': 0.2765, 'grad_norm': 9.059704780578613, 'learning_rate': 2.8383771200532094e-05, 'epoch': 2.16}


 45%|████▍     | 13500/30070 [6:09:16<8:48:09,  1.91s/it] 

{'loss': 0.2976, 'grad_norm': 7.325918197631836, 'learning_rate': 2.7552377785167942e-05, 'epoch': 2.24}


 47%|████▋     | 14000/30070 [6:26:01<8:34:46,  1.92s/it] 

{'loss': 0.3038, 'grad_norm': 74.92426300048828, 'learning_rate': 2.6720984369803794e-05, 'epoch': 2.33}


 48%|████▊     | 14500/30070 [6:44:35<29:27:37,  6.81s/it]

{'loss': 0.293, 'grad_norm': 17.987077713012695, 'learning_rate': 2.5889590954439642e-05, 'epoch': 2.41}


 50%|████▉     | 15000/30070 [7:26:36<20:32:40,  4.91s/it]

{'loss': 0.2696, 'grad_norm': 1.9816105365753174, 'learning_rate': 2.505819753907549e-05, 'epoch': 2.49}


 52%|█████▏    | 15500/30070 [8:09:28<16:15:47,  4.02s/it]

{'loss': 0.3027, 'grad_norm': 14.505837440490723, 'learning_rate': 2.422680412371134e-05, 'epoch': 2.58}


 53%|█████▎    | 16000/30070 [8:36:45<8:34:08,  2.19s/it] 

{'loss': 0.2982, 'grad_norm': 1.4382952451705933, 'learning_rate': 2.339541070834719e-05, 'epoch': 2.66}


 55%|█████▍    | 16500/30070 [8:57:04<19:23:48,  5.15s/it]

{'loss': 0.2703, 'grad_norm': 12.210532188415527, 'learning_rate': 2.256401729298304e-05, 'epoch': 2.74}


 57%|█████▋    | 17000/30070 [9:34:53<21:31:07,  5.93s/it]

{'loss': 0.3074, 'grad_norm': 14.998262405395508, 'learning_rate': 2.173262387761889e-05, 'epoch': 2.83}


 57%|█████▋    | 17154/30070 [9:46:11<8:43:58,  2.43s/it] 