In [None]:
%pip install evaluate

# models/get_models.py

In [None]:
from transformers import AutoModelForQuestionAnswering, AutoTokenizer
import torch

def get_model_and_tokenizer(
    model_checkpoint: str = "bert-base-cased",
    device: "str | torch.device | None" = None,
):
    """Load a pretrained question-answering model and its tokenizer.

    Args:
        model_checkpoint: Model id or local path passed to ``from_pretrained``.
        device: Target device. When ``None``, CUDA is used if available,
            otherwise CPU.

    Returns:
        ``(model, tokenizer, device)`` with the model already moved to
        ``device``, matching the behaviour of ``load_fine_tuned_model``.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = AutoModelForQuestionAnswering.from_pretrained(model_checkpoint)
    tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

    # Place the model on the device we report, so the returned triple is
    # consistent (previously the model stayed where from_pretrained put it).
    model.to(device)
    return model, tokenizer, device

def load_fine_tuned_model(model_path: str, device: str = None):
    """Load a fine-tuned question-answering model and tokenizer from disk.

    Args:
        model_path: Directory (or model id) passed to ``from_pretrained``.
        device: Target device; when ``None`` CUDA is chosen if available,
            otherwise CPU.

    Returns:
        ``(model, tokenizer, device)`` where the model has been moved to the
        returned device.
    """
    target = device
    if target is None:
        target = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = AutoModelForQuestionAnswering.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model.to(target)

    return model, tokenizer, target

# src/utils/metrics.py

In [None]:
from tqdm.auto import tqdm
import collections
import numpy as np
import evaluate

def compute_metrics(start_logits, end_logits, features, examples, n_best=20, max_answer_length=30):
    """Turn per-feature start/end logits into answer texts and score them.

    For every raw example, each tokenized feature derived from it is searched
    over the ``n_best`` highest start logits x ``n_best`` highest end logits.
    A candidate span is rejected when either endpoint has a ``None`` offset
    (token outside the context), when it ends before it starts, or when it is
    longer than ``max_answer_length`` tokens. The surviving span with the
    largest ``start + end`` logit sum wins (first one on ties); if none
    survive, the prediction is the empty string. Predictions are scored with
    the ``squad`` metric loaded through ``evaluate``.

    Args:
        start_logits: Per-feature start logits, indexed like ``features``.
        end_logits: Per-feature end logits, indexed like ``features``.
        features: Processed features exposing ``example_id`` and
            ``offset_mapping``.
        examples: Raw examples exposing ``id``, ``context`` and ``answers``.
        n_best: How many top start/end indices to consider per feature.
        max_answer_length: Maximum span length, in tokens.

    Returns:
        Whatever ``metric.compute`` returns for the squad metric.
    """
    metric = evaluate.load("squad")

    # Group feature positions by the example they were produced from.
    features_by_example = collections.defaultdict(list)
    for position, feature in enumerate(features):
        features_by_example[feature["example_id"]].append(position)

    predictions = []
    for example in tqdm(examples, desc="Computing metrics"):
        example_id = example["id"]
        context = example["context"]
        best_text = None
        best_score = None

        for feat_idx in features_by_example[example_id]:
            start_scores = start_logits[feat_idx]
            end_scores = end_logits[feat_idx]
            offsets = features[feat_idx]["offset_mapping"]

            # Indices of the n_best largest logits, highest first.
            top_starts = np.argsort(start_scores)[-1 : -n_best - 1 : -1].tolist()
            top_ends = np.argsort(end_scores)[-1 : -n_best - 1 : -1].tolist()

            for s in top_starts:
                for e in top_ends:
                    if offsets[s] is None or offsets[e] is None:
                        continue
                    if e < s or e - s + 1 > max_answer_length:
                        continue

                    text = context[offsets[s][0] : offsets[e][1]]
                    score = start_scores[s] + end_scores[e]
                    # Strict '>' keeps the earliest candidate on ties, exactly
                    # like max() over the candidates in generation order.
                    if best_score is None or score > best_score:
                        best_text = text
                        best_score = score

        predictions.append(
            {"id": example_id, "prediction_text": "" if best_score is None else best_text}
        )

    references = [{"id": ex["id"], "answers": ex["answers"]} for ex in examples]
    return metric.compute(predictions=predictions, references=references)

# src/preprocessing.py

In [None]:
from transformers import AutoTokenizer
from datasets import Dataset
from typing import Dict, List, Any
import collections

class DataPreprocessor:
    """Tokenize SQuAD-style (question, context) data into fixed-length features.

    Only the context (second sequence) is truncated. Overflowing windows are
    kept as extra features (``return_overflowing_tokens=True``), and ``stride``
    is forwarded to the tokenizer to control the overlap between them.
    """

    def __init__(self, model_checkpoint: str = "bert-base-cased", max_length: int = 384, stride: int = 128):
        self.tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
        self.max_length = max_length
        self.stride = stride

    def _tokenize(self, examples: Dict[str, List]):
        """Tokenize question/context pairs with the shared windowing settings."""
        # Strip stray whitespace around questions before tokenizing.
        questions = [q.strip() for q in examples["question"]]
        return self.tokenizer(
            questions,
            examples["context"],
            max_length=self.max_length,
            truncation="only_second",
            stride=self.stride,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            padding="max_length",
        )

    @staticmethod
    def _context_span(sequence_ids) -> tuple:
        """Return the (first, last) token indices whose sequence id is 1 (the context)."""
        idx = 0
        while sequence_ids[idx] != 1:
            idx += 1
        start = idx
        # Bounds check: don't run off the end if the context is the last segment.
        while idx < len(sequence_ids) and sequence_ids[idx] == 1:
            idx += 1
        return start, idx - 1

    def preprocess_training_examples(self, examples: Dict[str, List]) -> Dict[str, List]:
        """Tokenize a batch of training examples and attach answer token labels.

        Each resulting feature gets ``start_positions``/``end_positions`` for
        the first gold answer. Features whose window does not fully contain
        the answer, and examples that have no answer at all, get ``(0, 0)``.

        Args:
            examples: Batched columns ``question``, ``context`` and ``answers``
                (each answer has ``text`` and ``answer_start`` lists).

        Returns:
            The tokenizer output without ``offset_mapping`` and
            ``overflow_to_sample_mapping``, plus the two label columns.
        """
        inputs = self._tokenize(examples)

        offset_mapping = inputs.pop("offset_mapping")
        sample_map = inputs.pop("overflow_to_sample_mapping")
        answers = examples["answers"]
        start_positions = []
        end_positions = []

        for i, offset in enumerate(offset_mapping):
            answer = answers[sample_map[i]]

            # Unanswerable examples (empty answer lists) previously raised
            # IndexError; label them the same way as out-of-window answers.
            if not answer["answer_start"]:
                start_positions.append(0)
                end_positions.append(0)
                continue

            start_char = answer["answer_start"][0]
            end_char = start_char + len(answer["text"][0])
            context_start, context_end = self._context_span(inputs.sequence_ids(i))

            # Answer not fully inside this window's context -> (0, 0).
            if offset[context_start][0] > start_char or offset[context_end][1] < end_char:
                start_positions.append(0)
                end_positions.append(0)
                continue

            # Walk inward from each context edge to the tokens bracketing the answer.
            idx = context_start
            while idx <= context_end and offset[idx][0] <= start_char:
                idx += 1
            start_positions.append(idx - 1)

            idx = context_end
            while idx >= context_start and offset[idx][1] >= end_char:
                idx -= 1
            end_positions.append(idx + 1)

        inputs["start_positions"] = start_positions
        inputs["end_positions"] = end_positions
        return inputs

    def preprocess_validation_examples(self, examples: Dict[str, List]) -> Dict[str, List]:
        """Tokenize a batch of validation examples for later answer extraction.

        Each feature records the ``id`` of its source example in
        ``example_id``. In ``offset_mapping``, entries for tokens that are not
        part of the context are replaced by ``None`` so span post-processing
        can skip them.
        """
        inputs = self._tokenize(examples)

        sample_map = inputs.pop("overflow_to_sample_mapping")
        example_ids = []

        for i in range(len(inputs["input_ids"])):
            example_ids.append(examples["id"][sample_map[i]])

            sequence_ids = inputs.sequence_ids(i)
            inputs["offset_mapping"][i] = [
                o if sequence_ids[k] == 1 else None
                for k, o in enumerate(inputs["offset_mapping"][i])
            ]

        inputs["example_id"] = example_ids
        return inputs

    def prepare_datasets(self, raw_datasets: Dict[str, Dataset]) -> Dict[str, Dataset]:
        """Map the train/validation splits through their preprocessing functions.

        All original columns are removed. Because one example can yield
        several features, the output splits may have more rows than the input.
        """
        train_dataset = raw_datasets["train"].map(
            self.preprocess_training_examples,
            batched=True,
            remove_columns=raw_datasets["train"].column_names,
        )

        validation_dataset = raw_datasets["validation"].map(
            self.preprocess_validation_examples,
            batched=True,
            remove_columns=raw_datasets["validation"].column_names,
        )

        return {
            "train": train_dataset,
            "validation": validation_dataset,
        }

# src/train.py

In [None]:
from accelerate import Accelerator
from transformers import get_scheduler
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import default_data_collator
from tqdm.auto import tqdm
import torch
import numpy as np
from typing import Dict, Any

class Trainer:
    """Fine-tune a question-answering model with Hugging Face Accelerate.

    ``config`` keys read (defaults in parentheses): ``model_checkpoint``
    ("bert-base-cased"), ``batch_size`` (8), ``learning_rate`` (2e-5),
    ``mixed_precision`` ("fp16"), ``num_epochs`` (3), ``warmup_steps`` (0),
    ``output_dir`` ("bert-finetuned-squad-accelerate") and ``logging_steps``
    (100; a value <= 0 disables step logging).
    """

    def __init__(self, config: Dict[str, Any]):
        self.config = config
        self.model, self.tokenizer, self.device = get_model_and_tokenizer(
            config.get("model_checkpoint", "bert-base-cased")
        )

    def setup_training(self, train_dataset, validation_dataset):
        """Create dataloaders, optimizer, Accelerator, LR scheduler and progress bar.

        ``train_dataset.set_format("torch")`` modifies the caller's dataset in
        place. ``validation_dataset`` must still carry ``example_id`` and
        ``offset_mapping``; those columns are dropped only from the copy that
        is fed to the model.
        """
        batch_size = self.config.get("batch_size", 8)

        train_dataset.set_format("torch")
        # remove_columns returns a new dataset, so the caller's validation
        # features keep the columns needed for post-processing.
        validation_set = validation_dataset.remove_columns(["example_id", "offset_mapping"])
        validation_set.set_format("torch")

        self.train_dataloader = DataLoader(
            train_dataset,
            shuffle=True,
            collate_fn=default_data_collator,
            batch_size=batch_size,
        )
        self.eval_dataloader = DataLoader(
            validation_set, collate_fn=default_data_collator, batch_size=batch_size
        )

        self.optimizer = AdamW(self.model.parameters(), lr=self.config.get("learning_rate", 2e-5))
        self.accelerator = Accelerator(mixed_precision=self.config.get("mixed_precision", "fp16"))

        self.model, self.optimizer, self.train_dataloader, self.eval_dataloader = self.accelerator.prepare(
            self.model, self.optimizer, self.train_dataloader, self.eval_dataloader
        )

        # Read once so setup and the epoch loop in train() always agree.
        self.num_epochs = self.config.get("num_epochs", 3)
        num_training_steps = self.num_epochs * len(self.train_dataloader)

        self.lr_scheduler = get_scheduler(
            "linear",
            optimizer=self.optimizer,
            num_warmup_steps=self.config.get("warmup_steps", 0),
            num_training_steps=num_training_steps,
        )

        self.output_dir = self.config.get("output_dir", "bert-finetuned-squad-accelerate")
        # Draw the bar only on each machine's main process; others would
        # render duplicate bars in multi-process runs.
        self.progress_bar = tqdm(
            range(num_training_steps),
            disable=not self.accelerator.is_local_main_process,
        )

    def train_epoch(self, epoch: int):
        """Run one pass over the training dataloader, logging loss periodically.

        Args:
            epoch: Epoch number, used only in the log message.
        """
        self.model.train()
        logging_steps = self.config.get("logging_steps", 100)
        for step, batch in enumerate(self.train_dataloader):
            outputs = self.model(**batch)
            loss = outputs.loss
            self.accelerator.backward(loss)

            self.optimizer.step()
            self.lr_scheduler.step()
            self.optimizer.zero_grad()
            self.progress_bar.update(1)

            # Guard against logging_steps <= 0 (previously a ZeroDivisionError).
            if logging_steps > 0 and step % logging_steps == 0:
                # accelerator.print avoids one copy of the line per process.
                self.accelerator.print(f"Epoch {epoch}, Step {step}, Loss: {loss.item():.4f}")

    def evaluate(self, validation_dataset, raw_validation_dataset):
        """Compute SQuAD metrics on the validation split.

        Args:
            validation_dataset: Processed validation features (still holding
                ``example_id`` and ``offset_mapping``), in eval-dataloader order.
            raw_validation_dataset: The untokenized validation examples.

        Returns:
            The metrics dict produced by ``compute_metrics``.
        """
        self.model.eval()
        start_chunks = []
        end_chunks = []

        for batch in tqdm(
            self.eval_dataloader,
            desc="Evaluating",
            disable=not self.accelerator.is_local_main_process,
        ):
            with torch.no_grad():
                outputs = self.model(**batch)

            start_chunks.append(self.accelerator.gather(outputs.start_logits).cpu().numpy())
            end_chunks.append(self.accelerator.gather(outputs.end_logits).cpu().numpy())

        # The gathered logits may include padding across processes; trim them
        # to the real number of features.
        n_features = len(validation_dataset)
        start_logits = np.concatenate(start_chunks)[:n_features]
        end_logits = np.concatenate(end_chunks)[:n_features]

        return compute_metrics(
            start_logits, end_logits, validation_dataset, raw_validation_dataset
        )

    def save_model(self):
        """Save the unwrapped model (and, on the main process, the tokenizer) to ``output_dir``."""
        self.accelerator.wait_for_everyone()
        unwrapped_model = self.accelerator.unwrap_model(self.model)
        unwrapped_model.save_pretrained(self.output_dir, save_function=self.accelerator.save)
        if self.accelerator.is_main_process:
            self.tokenizer.save_pretrained(self.output_dir)

    def train(self, train_dataset, validation_dataset, raw_validation_dataset):
        """Train for the configured number of epochs, evaluating and saving after each.

        Epochs are numbered from 1 in all log messages.
        """
        self.setup_training(train_dataset, validation_dataset)

        for epoch in range(1, self.num_epochs + 1):
            self.accelerator.print(f"Starting epoch {epoch}")
            self.train_epoch(epoch)

            self.accelerator.print("Running evaluation...")
            metrics = self.evaluate(validation_dataset, raw_validation_dataset)
            self.accelerator.print(f"Epoch {epoch} metrics:", metrics)

            # Save a checkpoint after each epoch (same output_dir each time).
            self.save_model()

        self.accelerator.print("Training completed!")

# tests/evaluate.py

In [None]:
import torch
from transformers import pipeline

def _build_qa_pipeline(model_path: str):
    """Load the fine-tuned model once and wrap it in a question-answering pipeline."""
    model, tokenizer, device = load_fine_tuned_model(model_path)
    # Hand the already-loaded model (and its device) to the pipeline. Passing
    # model_path instead made pipeline() load a second copy from disk and
    # ignore the device chosen by load_fine_tuned_model.
    return pipeline("question-answering", model=model, tokenizer=tokenizer, device=device)


def _answer_sample(question_answerer, dataset, sample_index: int):
    """Run ``question_answerer`` on one dataset row and print a short report."""
    sample = dataset[sample_index]
    context = sample["context"]
    question = sample["question"]

    print("Question:", question)
    print("Context:", context[:200], "...")

    result = question_answerer(question=question, context=context)
    print("Predicted answer:", result["answer"])
    print("Ground truth:", sample["answers"])

    return result


def evaluate_model_on_sample(model_path: str, dataset, sample_index: int = 0):
    """Evaluate the model on a single sample from the dataset.

    Args:
        model_path: Directory of the fine-tuned model.
        dataset: Indexable dataset whose rows have ``context``, ``question``
            and ``answers``.
        sample_index: Row to evaluate.

    Returns:
        The pipeline's result dict for that row.
    """
    return _answer_sample(_build_qa_pipeline(model_path), dataset, sample_index)


def batch_evaluate(model_path: str, dataset, num_samples: int = 5):
    """Evaluate the model on the first ``num_samples`` rows of ``dataset``.

    The model is loaded once and reused for every sample.

    Returns:
        A list of pipeline result dicts, one per evaluated row.
    """
    count = min(num_samples, len(dataset))
    if count <= 0:
        return []

    question_answerer = _build_qa_pipeline(model_path)
    results = []
    for i in range(count):
        print(f"\n--- Sample {i+1} ---")
        results.append(_answer_sample(question_answerer, dataset, i))

    return results

# src/main.py

In [None]:
from datasets import load_dataset

def main():
    """Fine-tune a QA model on SQuAD, then print predictions for a few validation rows."""
    # Training configuration. It is defined first so that preprocessing uses
    # the same checkpoint (and therefore tokenizer) as the model being trained.
    config = {
        "model_checkpoint": "bert-base-cased",
        "batch_size": 8,
        "learning_rate": 2e-5,
        "num_epochs": 3,
        "output_dir": "bert-finetuned-squad-accelerate",
        "mixed_precision": "fp16",
        "warmup_steps": 0,
        "logging_steps": 100,
    }

    # Load dataset
    print("Loading dataset...")
    raw_datasets = load_dataset("squad")

    # Preprocess data
    print("Preprocessing data...")
    preprocessor = DataPreprocessor(model_checkpoint=config["model_checkpoint"])
    processed_datasets = preprocessor.prepare_datasets(raw_datasets)

    # Train model
    print("Starting training...")
    trainer = Trainer(config)
    trainer.train(
        processed_datasets["train"],
        processed_datasets["validation"],
        raw_datasets["validation"],
    )

    # Evaluate on a few samples
    print("\nEvaluating on sample data...")
    batch_evaluate(config["output_dir"], raw_datasets["validation"], num_samples=3)


if __name__ == "__main__":
    main()