In [None]:
%pip install evaluate

# models/get_models.py

In [2]:
from transformers import AutoModelForQuestionAnswering, AutoTokenizer
import torch

def _resolve_device(device: "str | torch.device | None" = None) -> torch.device:
    """Return ``device`` as a ``torch.device``; default to CUDA when available, else CPU."""
    if device is None:
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return device if isinstance(device, torch.device) else torch.device(device)


def get_model_and_tokenizer(model_checkpoint: str = "bert-base-cased", device: "str | torch.device | None" = None):
    """Load a (not yet fine-tuned) QA model and its tokenizer from ``model_checkpoint``.

    The model is NOT moved to ``device``: placement is left to the caller
    (e.g. ``Accelerator.prepare``). The resolved device is returned for reference.

    Returns:
        ``(model, tokenizer, device)`` where ``device`` is a ``torch.device``.
    """
    device = _resolve_device(device)
    model = AutoModelForQuestionAnswering.from_pretrained(model_checkpoint)
    tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
    return model, tokenizer, device


def load_fine_tuned_model(model_path: str, device: "str | torch.device | None" = None):
    """Load a fine-tuned QA model and tokenizer from ``model_path`` and move the model to ``device``.

    Returns:
        ``(model, tokenizer, device)`` where ``device`` is a ``torch.device``.
    """
    device = _resolve_device(device)
    model = AutoModelForQuestionAnswering.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model.to(device)
    return model, tokenizer, device

# src/utils/metrics.py

In [3]:
from tqdm.auto import tqdm
import collections
import numpy as np
import evaluate

def _group_features_by_example(features):
    """Return a mapping from each ``example_id`` to the indices of the features built from it."""
    example_to_features = collections.defaultdict(list)
    for idx, feature in enumerate(features):
        example_to_features[feature["example_id"]].append(idx)
    return example_to_features


def _candidate_answers(start_logit, end_logit, offsets, context, n_best, max_answer_length):
    """Yield valid answer spans for one feature from its top-``n_best`` start/end logits.

    A span is kept only when both endpoint offsets are not ``None`` (i.e. they fall
    inside the context) and it is 1..``max_answer_length`` tokens long.
    """
    # argsort ascending, then walk backwards to take the n_best highest logits.
    start_indexes = np.argsort(start_logit)[-1 : -n_best - 1 : -1].tolist()
    end_indexes = np.argsort(end_logit)[-1 : -n_best - 1 : -1].tolist()
    for start_index in start_indexes:
        for end_index in end_indexes:
            if offsets[start_index] is None or offsets[end_index] is None:
                continue
            if end_index < start_index or end_index - start_index + 1 > max_answer_length:
                continue
            yield {
                "text": context[offsets[start_index][0] : offsets[end_index][1]],
                "logit_score": start_logit[start_index] + end_logit[end_index],
            }


def _postprocess_qa_predictions(start_logits, end_logits, features, examples, n_best, max_answer_length):
    """Pick, for every example, the highest-scoring span across all of its features.

    Returns a list of ``{"id", "prediction_text"}`` dicts; examples with no valid
    span get an empty ``prediction_text``.
    """
    example_to_features = _group_features_by_example(features)
    predicted_answers = []
    for example in tqdm(examples, desc="Computing metrics"):
        example_id = example["id"]
        context = example["context"]
        answers = []
        for feature_index in example_to_features[example_id]:
            answers.extend(
                _candidate_answers(
                    start_logits[feature_index],
                    end_logits[feature_index],
                    features[feature_index]["offset_mapping"],
                    context,
                    n_best,
                    max_answer_length,
                )
            )
        best_answer = max(answers, key=lambda x: x["logit_score"], default=None)
        prediction_text = best_answer["text"] if best_answer is not None else ""
        predicted_answers.append({"id": example_id, "prediction_text": prediction_text})
    return predicted_answers


def compute_metrics(start_logits, end_logits, features, examples, n_best=20, max_answer_length=30):
    """Convert start/end logits into text predictions and score them with the SQuAD metric.

    Args:
        start_logits, end_logits: per-feature logit arrays, indexed in the same
            order as ``features``.
        features: tokenized features carrying ``example_id`` and ``offset_mapping``
            (offsets are ``None`` outside the context).
        examples: raw examples with ``id``, ``context`` and ``answers``.
        n_best: how many top start/end positions to combine per feature.
        max_answer_length: longest allowed answer span, in tokens.

    Returns:
        The dict produced by ``evaluate``'s "squad" metric (``exact_match``, ``f1``).
    """
    metric = evaluate.load("squad")
    predicted_answers = _postprocess_qa_predictions(
        start_logits, end_logits, features, examples, n_best, max_answer_length
    )
    theoretical_answers = [{"id": ex["id"], "answers": ex["answers"]} for ex in examples]
    return metric.compute(predictions=predicted_answers, references=theoretical_answers)

2025-09-10 14:29:03.223548: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1757514543.432360      36 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1757514543.492228      36 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


# src/preprocessing.py

In [4]:
from transformers import AutoTokenizer
from datasets import Dataset
from typing import Dict, List, Any
import collections

class DataPreprocessor:
    """Tokenize SQuAD-style QA examples into overlapping fixed-length features.

    Contexts longer than ``max_length`` tokens are split into several windows
    overlapping by ``stride`` tokens (``return_overflowing_tokens=True``), so one
    raw example may yield several features.
    """

    def __init__(self, model_checkpoint: str = "bert-base-cased", max_length: int = 384, stride: int = 128):
        self.tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
        self.max_length = max_length
        self.stride = stride

    def _tokenize(self, examples: Dict[str, List]):
        """Tokenize (question, context) pairs with the shared windowing settings.

        Questions are whitespace-stripped; only the context is truncated, and
        every window is padded to ``max_length``.
        """
        questions = [q.strip() for q in examples["question"]]
        return self.tokenizer(
            questions,
            examples["context"],
            max_length=self.max_length,
            truncation="only_second",
            stride=self.stride,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            padding="max_length",
        )

    @staticmethod
    def _context_token_span(sequence_ids: List) -> tuple:
        """Return the (first, last) token indices whose sequence id is 1, i.e. the context."""
        idx = 0
        while sequence_ids[idx] != 1:
            idx += 1
        context_start = idx
        # Bounds check guards tokenizers that emit no token after the context.
        while idx < len(sequence_ids) and sequence_ids[idx] == 1:
            idx += 1
        return context_start, idx - 1

    def preprocess_training_examples(self, examples: Dict[str, List]) -> Dict[str, List]:
        """Tokenize training examples and label each feature with answer token positions.

        Only the first gold answer is used. Features whose window does not fully
        contain the answer, and examples without any answer (e.g. SQuAD v2
        unanswerable questions), are labelled (0, 0).
        """
        inputs = self._tokenize(examples)
        offset_mapping = inputs.pop("offset_mapping")
        sample_map = inputs.pop("overflow_to_sample_mapping")
        answers = examples["answers"]
        start_positions = []
        end_positions = []

        for i, offset in enumerate(offset_mapping):
            answer = answers[sample_map[i]]
            if not answer["answer_start"]:
                # No gold span: indexing [0] would raise IndexError.
                start_positions.append(0)
                end_positions.append(0)
                continue

            start_char = answer["answer_start"][0]
            end_char = start_char + len(answer["text"][0])
            context_start, context_end = self._context_token_span(inputs.sequence_ids(i))

            # If the answer is not fully inside this window's context, label is (0, 0)
            if offset[context_start][0] > start_char or offset[context_end][1] < end_char:
                start_positions.append(0)
                end_positions.append(0)
                continue

            # Otherwise walk inwards to the tokens covering the answer's first/last chars
            idx = context_start
            while idx <= context_end and offset[idx][0] <= start_char:
                idx += 1
            start_positions.append(idx - 1)

            idx = context_end
            while idx >= context_start and offset[idx][1] >= end_char:
                idx -= 1
            end_positions.append(idx + 1)

        inputs["start_positions"] = start_positions
        inputs["end_positions"] = end_positions
        return inputs

    def preprocess_validation_examples(self, examples: Dict[str, List]) -> Dict[str, List]:
        """Tokenize validation examples, keeping what post-processing needs.

        Adds ``example_id`` (the source example's id) to each feature and replaces
        offsets of non-context tokens with ``None`` so they can never be chosen as
        answer boundaries.
        """
        inputs = self._tokenize(examples)
        sample_map = inputs.pop("overflow_to_sample_mapping")
        example_ids = []

        for i in range(len(inputs["input_ids"])):
            sample_idx = sample_map[i]
            example_ids.append(examples["id"][sample_idx])

            sequence_ids = inputs.sequence_ids(i)
            offset = inputs["offset_mapping"][i]
            inputs["offset_mapping"][i] = [
                o if sequence_ids[k] == 1 else None for k, o in enumerate(offset)
            ]

        inputs["example_id"] = example_ids
        return inputs

    def prepare_datasets(self, raw_datasets: Dict[str, Dataset]) -> Dict[str, Dataset]:
        """Map both splits through their preprocessors, dropping the raw columns."""
        train_dataset = raw_datasets["train"].map(
            self.preprocess_training_examples,
            batched=True,
            remove_columns=raw_datasets["train"].column_names,
        )

        validation_dataset = raw_datasets["validation"].map(
            self.preprocess_validation_examples,
            batched=True,
            remove_columns=raw_datasets["validation"].column_names,
        )

        return {
            "train": train_dataset,
            "validation": validation_dataset
        }

# src/train.py

In [5]:
from accelerate import Accelerator
from transformers import get_scheduler
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import default_data_collator
from tqdm.auto import tqdm
import torch
import numpy as np
from typing import Dict, Any

class Trainer:
    """Fine-tune an extractive QA model with Hugging Face Accelerate.

    Config keys read via ``config.get`` (defaults in parentheses):
    ``model_checkpoint`` ("bert-base-cased"), ``batch_size`` (8),
    ``learning_rate`` (2e-5), ``mixed_precision`` ("fp16"), ``num_epochs`` (3),
    ``warmup_steps`` (0), ``output_dir`` ("bert-finetuned-squad-accelerate"),
    ``logging_steps`` (1000).
    """

    def __init__(self, config: Dict[str, Any]):
        self.config = config
        # The returned device is informational; Accelerate handles model placement.
        self.model, self.tokenizer, self.device = get_model_and_tokenizer(
            config.get("model_checkpoint", "bert-base-cased")
        )

    def _resolve_mixed_precision(self) -> str:
        """Return the configured mixed-precision mode, downgrading fp16 to "no" without CUDA.

        Accelerate rejects fp16 on CPU, so the default config would otherwise
        crash on CPU-only machines.
        """
        mode = self.config.get("mixed_precision", "fp16")
        if mode == "fp16" and not torch.cuda.is_available():
            print("CUDA is not available; disabling fp16 mixed precision.")
            return "no"
        return mode

    def setup_training(self, train_dataset, validation_dataset):
        """Build dataloaders, optimizer, Accelerator, linear LR schedule and progress bar.

        ``train_dataset`` is switched to torch format in place. The validation
        features are copied without ``example_id``/``offset_mapping`` (not model
        inputs) before formatting.
        """
        train_dataset.set_format("torch")
        validation_set = validation_dataset.remove_columns(["example_id", "offset_mapping"])
        validation_set.set_format("torch")

        batch_size = self.config.get("batch_size", 8)
        self.train_dataloader = DataLoader(
            train_dataset,
            shuffle=True,
            collate_fn=default_data_collator,
            batch_size=batch_size,
        )
        self.eval_dataloader = DataLoader(
            validation_set, collate_fn=default_data_collator, batch_size=batch_size
        )

        self.optimizer = AdamW(self.model.parameters(), lr=self.config.get("learning_rate", 2e-5))
        self.accelerator = Accelerator(mixed_precision=self._resolve_mixed_precision())

        self.model, self.optimizer, self.train_dataloader, self.eval_dataloader = self.accelerator.prepare(
            self.model, self.optimizer, self.train_dataloader, self.eval_dataloader
        )

        # Step count uses the prepared (possibly sharded) dataloader length.
        num_train_epochs = self.config.get("num_epochs", 3)
        num_update_steps_per_epoch = len(self.train_dataloader)
        num_training_steps = num_train_epochs * num_update_steps_per_epoch

        self.lr_scheduler = get_scheduler(
            "linear",
            optimizer=self.optimizer,
            num_warmup_steps=self.config.get("warmup_steps", 0),
            num_training_steps=num_training_steps,
        )

        self.output_dir = self.config.get("output_dir", "bert-finetuned-squad-accelerate")
        self.progress_bar = tqdm(
            range(num_training_steps), disable=not self.accelerator.is_local_main_process
        )

    def train_epoch(self, epoch: int):
        """Run one pass over the training dataloader, logging every ``logging_steps`` steps."""
        self.model.train()
        logging_steps = self.config.get("logging_steps", 1000)
        for step, batch in enumerate(self.train_dataloader):
            outputs = self.model(**batch)
            loss = outputs.loss
            self.accelerator.backward(loss)

            self.optimizer.step()
            self.lr_scheduler.step()
            self.optimizer.zero_grad()
            self.progress_bar.update(1)

            if step % logging_steps == 0:
                self.accelerator.print(f"Epoch {epoch}, Step {step}, Loss: {loss.item():.4f}")

    def evaluate(self, validation_dataset, raw_validation_dataset):
        """Predict on the validation features and return SQuAD metrics.

        Logits are gathered across processes, then truncated to the number of
        features to drop any samples the distributed sampler duplicated.
        """
        self.model.eval()
        start_logits = []
        end_logits = []

        for batch in tqdm(
            self.eval_dataloader,
            desc="Evaluating",
            disable=not self.accelerator.is_local_main_process,
        ):
            with torch.no_grad():
                outputs = self.model(**batch)

            start_logits.append(self.accelerator.gather(outputs.start_logits).cpu().numpy())
            end_logits.append(self.accelerator.gather(outputs.end_logits).cpu().numpy())

        start_logits = np.concatenate(start_logits)[: len(validation_dataset)]
        end_logits = np.concatenate(end_logits)[: len(validation_dataset)]

        return compute_metrics(
            start_logits, end_logits, validation_dataset, raw_validation_dataset
        )

    def save_model(self):
        """Save the unwrapped model (and, on the main process, the tokenizer) to ``output_dir``."""
        self.accelerator.wait_for_everyone()
        unwrapped_model = self.accelerator.unwrap_model(self.model)
        unwrapped_model.save_pretrained(self.output_dir, save_function=self.accelerator.save)
        if self.accelerator.is_main_process:
            self.tokenizer.save_pretrained(self.output_dir)

    def train(self, train_dataset, validation_dataset, raw_validation_dataset):
        """Train, evaluate and checkpoint once per epoch.

        Returns:
            The metrics dict from the last epoch, or ``None`` if ``num_epochs`` is 0.
        """
        self.setup_training(train_dataset, validation_dataset)

        metrics = None
        for epoch in range(self.config.get("num_epochs", 3)):
            self.accelerator.print(f"Starting epoch {epoch + 1}")
            self.train_epoch(epoch)

            self.accelerator.print("Running evaluation...")
            metrics = self.evaluate(validation_dataset, raw_validation_dataset)
            self.accelerator.print(f"Epoch {epoch} metrics:", metrics)

            # Save checkpoint after each epoch (overwrites the previous one)
            self.save_model()

        self.progress_bar.close()
        self.accelerator.print("Training completed!")
        return metrics

# tests/evaluate.py

In [6]:
import torch
from transformers import pipeline

def _build_qa_pipeline(model_path: str):
    """Load the fine-tuned model once and wrap it in a question-answering pipeline on its device."""
    model, tokenizer, device = load_fine_tuned_model(model_path)
    # Reuse the already-loaded model instead of letting the pipeline reload it from disk.
    return pipeline("question-answering", model=model, tokenizer=tokenizer, device=device)


def evaluate_model_on_sample(model_path: str, dataset, sample_index: int = 0, question_answerer=None):
    """Evaluate the model on a single sample from the dataset.

    Prints the question, a 200-character context preview, the prediction and
    the gold answers, and returns the pipeline's result dict.

    Args:
        question_answerer: optional pre-built QA pipeline; when ``None`` one is
            built from ``model_path`` (loads the model).
    """
    if question_answerer is None:
        question_answerer = _build_qa_pipeline(model_path)

    sample = dataset[sample_index]
    context = sample["context"]
    question = sample["question"]

    print("Question:", question)
    print("Context:", context[:200], "...")

    result = question_answerer(question=question, context=context)
    print("Predicted answer:", result["answer"])
    print("Ground truth:", sample["answers"])

    return result


def batch_evaluate(model_path: str, dataset, num_samples: int = 5):
    """Evaluate the model on the first ``num_samples`` samples, loading it only once."""
    count = min(num_samples, len(dataset))
    if count <= 0:
        return []

    question_answerer = _build_qa_pipeline(model_path)
    results = []
    for i in range(count):
        print(f"\n--- Sample {i+1} ---")
        results.append(
            evaluate_model_on_sample(model_path, dataset, i, question_answerer=question_answerer)
        )

    return results

# src/main.py

In [7]:
from datasets import load_dataset

# Fetch SQuAD from the Hub.
print("Loading dataset...")
raw_datasets = load_dataset("squad")

# Tokenize both splits into fixed-length, overlapping features.
print("Preprocessing data...")
preprocessor = DataPreprocessor(model_checkpoint="bert-base-cased")
processed_datasets = preprocessor.prepare_datasets(raw_datasets)

# Hyper-parameters handed to Trainer.
config = dict(
    model_checkpoint="bert-base-cased",
    batch_size=8,
    learning_rate=2e-5,
    num_epochs=3,
    output_dir="bert-finetuned-squad-accelerate",
    mixed_precision="fp16",
    warmup_steps=0,
    logging_steps=100,
)

Loading dataset...


README.md: 0.00B [00:00, ?B/s]

plain_text/train-00000-of-00001.parquet:   0%|          | 0.00/14.5M [00:00<?, ?B/s]

plain_text/validation-00000-of-00001.par(…):   0%|          | 0.00/1.82M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/87599 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/10570 [00:00<?, ? examples/s]

Preprocessing data...


tokenizer_config.json:   0%|          | 0.00/49.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/436k [00:00<?, ?B/s]

Map:   0%|          | 0/87599 [00:00<?, ? examples/s]

Map:   0%|          | 0/10570 [00:00<?, ? examples/s]

In [8]:
# Fine-tune, evaluating and checkpointing after every epoch.
print("Starting training...")
trainer = Trainer(config)
trainer.train(
    train_dataset=processed_datasets["train"],
    validation_dataset=processed_datasets["validation"],
    raw_validation_dataset=raw_datasets["validation"],
)

Starting training...


model.safetensors:   0%|          | 0.00/436M [00:00<?, ?B/s]

Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


  0%|          | 0/33276 [00:00<?, ?it/s]

Starting epoch 1
Epoch 0, Step 0, Loss: 5.9636
Epoch 0, Step 100, Loss: 1.9534
Epoch 0, Step 200, Loss: 2.4159
Epoch 0, Step 300, Loss: 1.6487
Epoch 0, Step 400, Loss: 1.6314
Epoch 0, Step 500, Loss: 1.9304
Epoch 0, Step 600, Loss: 0.8628
Epoch 0, Step 700, Loss: 1.4906
Epoch 0, Step 800, Loss: 1.1639
Epoch 0, Step 900, Loss: 0.9081
Epoch 0, Step 1000, Loss: 1.6234
Epoch 0, Step 1100, Loss: 1.0068
Epoch 0, Step 1200, Loss: 1.1190
Epoch 0, Step 1300, Loss: 2.4299
Epoch 0, Step 1400, Loss: 1.1795
Epoch 0, Step 1500, Loss: 0.4383
Epoch 0, Step 1600, Loss: 1.4404
Epoch 0, Step 1700, Loss: 1.6171
Epoch 0, Step 1800, Loss: 1.2412
Epoch 0, Step 1900, Loss: 0.7952
Epoch 0, Step 2000, Loss: 0.9616
Epoch 0, Step 2100, Loss: 1.1313
Epoch 0, Step 2200, Loss: 1.1318
Epoch 0, Step 2300, Loss: 1.4847
Epoch 0, Step 2400, Loss: 0.8708
Epoch 0, Step 2500, Loss: 0.3251
Epoch 0, Step 2600, Loss: 1.1906
Epoch 0, Step 2700, Loss: 1.4669
Epoch 0, Step 2800, Loss: 1.1511
Epoch 0, Step 2900, Loss: 0.9333
Epoch

Evaluating:   0%|          | 0/1353 [00:00<?, ?it/s]

Downloading builder script: 0.00B [00:00, ?B/s]

Downloading extra modules: 0.00B [00:00, ?B/s]

Computing metrics:   0%|          | 0/10570 [00:00<?, ?it/s]

Epoch 0 metrics: {'exact_match': 79.68779564806054, 'f1': 87.45568714195367}
Starting epoch 2
Epoch 1, Step 0, Loss: 0.7739
Epoch 1, Step 100, Loss: 0.4424
Epoch 1, Step 200, Loss: 0.1529
Epoch 1, Step 300, Loss: 0.8527
Epoch 1, Step 400, Loss: 0.9178
Epoch 1, Step 500, Loss: 1.0688
Epoch 1, Step 600, Loss: 1.3953
Epoch 1, Step 700, Loss: 0.6954
Epoch 1, Step 800, Loss: 0.5115
Epoch 1, Step 900, Loss: 0.6684
Epoch 1, Step 1000, Loss: 0.4616
Epoch 1, Step 1100, Loss: 0.5031
Epoch 1, Step 1200, Loss: 0.4162
Epoch 1, Step 1300, Loss: 0.7286
Epoch 1, Step 1400, Loss: 0.3105
Epoch 1, Step 1500, Loss: 0.0968
Epoch 1, Step 1600, Loss: 0.2278
Epoch 1, Step 1700, Loss: 0.6992
Epoch 1, Step 1800, Loss: 0.7637
Epoch 1, Step 1900, Loss: 0.5951
Epoch 1, Step 2000, Loss: 0.5828
Epoch 1, Step 2100, Loss: 0.3203
Epoch 1, Step 2200, Loss: 0.4728
Epoch 1, Step 2300, Loss: 0.4529
Epoch 1, Step 2400, Loss: 1.4316
Epoch 1, Step 2500, Loss: 0.6368
Epoch 1, Step 2600, Loss: 0.9069
Epoch 1, Step 2700, Loss: 0

Evaluating:   0%|          | 0/1353 [00:00<?, ?it/s]

Computing metrics:   0%|          | 0/10570 [00:00<?, ?it/s]

Epoch 1 metrics: {'exact_match': 80.85146641438033, 'f1': 88.42013128202687}
Starting epoch 3
Epoch 2, Step 0, Loss: 0.9548
Epoch 2, Step 100, Loss: 0.2694
Epoch 2, Step 200, Loss: 0.8582
Epoch 2, Step 300, Loss: 0.2784
Epoch 2, Step 400, Loss: 0.7177
Epoch 2, Step 500, Loss: 0.4609
Epoch 2, Step 600, Loss: 0.1685
Epoch 2, Step 700, Loss: 0.5787
Epoch 2, Step 800, Loss: 0.1948
Epoch 2, Step 900, Loss: 0.4291
Epoch 2, Step 1000, Loss: 0.3784
Epoch 2, Step 1100, Loss: 0.3239
Epoch 2, Step 1200, Loss: 0.8043
Epoch 2, Step 1300, Loss: 0.5281
Epoch 2, Step 1400, Loss: 0.1960
Epoch 2, Step 1500, Loss: 0.8635
Epoch 2, Step 1600, Loss: 0.3148
Epoch 2, Step 1700, Loss: 0.8819
Epoch 2, Step 1800, Loss: 0.4452
Epoch 2, Step 1900, Loss: 0.5352
Epoch 2, Step 2000, Loss: 0.5895
Epoch 2, Step 2100, Loss: 0.4872
Epoch 2, Step 2200, Loss: 0.7874
Epoch 2, Step 2300, Loss: 0.3561
Epoch 2, Step 2400, Loss: 0.5715
Epoch 2, Step 2500, Loss: 0.6476
Epoch 2, Step 2600, Loss: 0.8600
Epoch 2, Step 2700, Loss: 0

Evaluating:   0%|          | 0/1353 [00:00<?, ?it/s]

Computing metrics:   0%|          | 0/10570 [00:00<?, ?it/s]

Epoch 2 metrics: {'exact_match': 80.99337748344371, 'f1': 88.6620765461844}
Training completed!


In [9]:
# Spot-check the saved checkpoint on the first few raw validation examples.
print("\nEvaluating on sample data...")
batch_evaluate(
    model_path=config["output_dir"],
    dataset=raw_datasets["validation"],
    num_samples=3,
)


Evaluating on sample data...

--- Sample 1 ---


Device set to use cuda:0


Question: Which NFL team represented the AFC at Super Bowl 50?
Context: Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated ...
Predicted answer: Denver Broncos
Ground truth: {'text': ['Denver Broncos', 'Denver Broncos', 'Denver Broncos'], 'answer_start': [177, 177, 177]}

--- Sample 2 ---


Device set to use cuda:0


Question: Which NFL team represented the NFC at Super Bowl 50?
Context: Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated ...
Predicted answer: Carolina Panthers
Ground truth: {'text': ['Carolina Panthers', 'Carolina Panthers', 'Carolina Panthers'], 'answer_start': [249, 249, 249]}

--- Sample 3 ---


Device set to use cuda:0


Question: Where did Super Bowl 50 take place?
Context: Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated ...
Predicted answer: Levi's Stadium in the San Francisco Bay Area at Santa Clara, California
Ground truth: {'text': ['Santa Clara, California', "Levi's Stadium", "Levi's Stadium in the San Francisco Bay Area at Santa Clara, California."], 'answer_start': [403, 355, 355]}


[{'score': 0.9579598903656006,
  'start': 177,
  'end': 191,
  'answer': 'Denver Broncos'},
 {'score': 0.7057539224624634,
  'start': 249,
  'end': 266,
  'answer': 'Carolina Panthers'},
 {'score': 0.2681267559528351,
  'start': 355,
  'end': 426,
  'answer': "Levi's Stadium in the San Francisco Bay Area at Santa Clara, California"}]