In [1]:
%load_ext autoreload
%autoreload 2
# Enables autoreload; learn more at https://docs.databricks.com/en/files/workspace-modules.html#autoreload-for-python-modules
# To disable autoreload; run %autoreload 0

In [2]:
%pip install torchinfo
%restart_python

Note: you may need to restart the kernel to use updated packages.


UsageError: Line magic function `%restart_python` not found.


In [1]:
import os

from bacp import BaCPLearner, BaCPTrainer, BaCPTrainingArgumentsLLM
from models import EncoderProjectionNetwork, ClassificationNetwork
from unstructured_pruning import MagnitudePrune, MovementPrune, LocalMagnitudePrune, LocalMovementPrune, WandaPrune, PRUNER_DICT, check_model_sparsity
from LLM_trainer import LLMTrainer, LLMTrainingArguments
from dataset_utils import get_glue_data
from logger import Logger

import torch
import torch.nn as nn
import torch.optim as optim
from datasets import load_dataset 

from tqdm import tqdm
from torchinfo import summary

from datasets.utils.logging import disable_progress_bar
disable_progress_bar()
os.environ["HF_DATASETS_CACHE"] = "/dbfs/hf_datasets"
os.environ["TOKENIZERS_PARALLELISM"] = "false" 

from utils import *
from constants import *

device = get_device()
print(f"{device = }")
BATCH_SIZE_DISTILBERT = 64
NUM_WORKERS = 24


  from .autonotebook import tqdm as notebook_tqdm


device = 'cpu'


In [2]:
dataset = load_dataset("rajpurkar/squad")

In [3]:
print(dataset)
trainset = dataset['train']
valset = dataset['validation']

DatasetDict({
    train: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 87599
    })
    validation: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 10570
    })
})


In [4]:
trainset['answers'][0]

{'text': ['Saint Bernadette Soubirous'], 'answer_start': [515]}

In [5]:
from transformers import AutoTokenizer, DistilBertForQuestionAnswering
import torch

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = DistilBertForQuestionAnswering.from_pretrained("distilbert-base-uncased")

question = "Who was Jim Henson?"
context = "Jim Henson was a nice puppet creator."
answer = "nice puppet"
start_char = context.index(answer)
end_char = start_char + len(answer)

Some weights of DistilBertForQuestionAnswering were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [6]:
def tokenize_fn(example):
    return tokenizer(
        example['question'],
        example['context'],
        padding='max_length',
        truncation=True,
        return_offsets_mapping=True,
        return_tensors='pt'
    )

# Prepare example
example = {
    "question": question,
    "context": context,
    "answer": answer
}

inputs = tokenize_fn(example)

# Extract and remove offset_mapping (needed for token-char alignment)
offset_mapping = inputs.pop("offset_mapping")[0]
input_ids = inputs["input_ids"][0]

# Get character-level start and end
start_char = context.index(answer)
end_char = start_char + len(answer)

# Map character positions to token indices
start_token = end_token = None
for idx, (start, end) in enumerate(offset_mapping):
    if start_token is None and start <= start_char < end:
        start_token = idx
    if start < end_char <= end:
        end_token = idx

# Fallback in case alignment fails
if start_token is None or end_token is None:
    raise ValueError("Could not find token span for the answer.")

# Inference
with torch.no_grad():
    outputs = model(**inputs)

answer_start_index = outputs.start_logits.argmax()
answer_end_index = outputs.end_logits.argmax()

# Decode predicted answer
predict_answer_tokens = input_ids[answer_start_index : answer_end_index + 1]
print("Predicted:", tokenizer.decode(predict_answer_tokens, skip_special_tokens=True))

# Add ground truth span for loss computation
inputs["start_positions"] = torch.tensor([start_token])
inputs["end_positions"] = torch.tensor([end_token])

# Training: forward pass with gold spans
outputs = model(**inputs)
loss = outputs.loss
print("Loss:", round(loss.item(), 2))


Predicted: who was jim henson? jim henson was a nice puppet creator.
Loss: 6.26


### DataLoader

In [7]:
import torch
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from transformers import PreTrainedTokenizerBase

class SquadDataset(Dataset):
    def __init__(self, encodings, answer_starts, answer_ends):

        self.encodings = {k: (v if isinstance(v, torch.Tensor) else torch.tensor(v, dtype=torch.long))
                          for k, v in encodings.items()}
        self.answer_starts = torch.tensor(answer_starts, dtype=torch.long)
        self.answer_ends = torch.tensor(answer_ends, dtype=torch.long)

    def __getitem__(self, idx):
        item = {k: v[idx].clone().detach() for k, v in self.encodings.items()}
        item['start_positions'] = self.answer_starts[idx].clone().detach()
        item['end_positions'] = self.answer_ends[idx].clone().detach()
        return item

    def __len__(self):
        return self.answer_starts.size(0)


def get_squad_data(
    tokenizer: PreTrainedTokenizerBase,
    batch_size: int,
    num_workers: int = 0,
    max_length: int = 384,
    doc_stride: int = 128,
    subset_size: int = None
):

    raw = load_dataset("rajpurkar/squad")

    # Optional subset for fast iterations
    if subset_size:
        for split in ('train', 'validation'):
            if split in raw:
                raw[split] = raw[split].shuffle(seed=42).select(range(min(subset_size, len(raw[split]))))

    def preprocess(examples):
        questions = [q.strip() for q in examples['question']]
        tokenized = tokenizer(
            questions,
            examples['context'],
            truncation='only_second',
            max_length=max_length,
            stride=doc_stride,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            padding='max_length'
        )
        sample_map = tokenized.pop('overflow_to_sample_mapping')
        offsets = tokenized.pop('offset_mapping')

        starts, ends = [], []
        for i, offset in enumerate(offsets):
            input_ids = tokenized['input_ids'][i]
            cls_idx = input_ids.index(tokenizer.cls_token_id)
            seq_ids = tokenized.sequence_ids(i)
            sample_idx = sample_map[i]
            answer = examples['answers'][sample_idx]
            if len(answer['answer_start']) == 0:
                starts.append(cls_idx)
                ends.append(cls_idx)
            else:
                start_char = answer['answer_start'][0]
                end_char = start_char + len(answer['text'][0])
                # find window token indices
                token_start, token_end = 0, len(input_ids) - 1
                while seq_ids[token_start] != 1:
                    token_start += 1
                while seq_ids[token_end] != 1:
                    token_end -= 1
                # answer out of window
                if not (offset[token_start][0] <= start_char and offset[token_end][1] >= end_char):
                    starts.append(cls_idx)
                    ends.append(cls_idx)
                else:
                    # move to exact token boundaries
                    while token_start < len(offset) and offset[token_start][0] <= start_char:
                        token_start += 1
                    starts.append(token_start - 1)
                    while offset[token_end][1] >= end_char:
                        token_end -= 1
                    ends.append(token_end + 1)
        tokenized['start_positions'] = starts
        tokenized['end_positions'] = ends
        return tokenized

    datasets = {}
    dataloaders = {}
    for split in ('train', 'validation'):
        if split in raw:
            proc = raw[split].map(
                preprocess,
                batched=True,
                remove_columns=raw[split].column_names
            )
            proc.set_format(type='torch')
            encodings = {
                'input_ids': proc['input_ids'],
                'attention_mask': proc['attention_mask']
            }
            ds = SquadDataset(encodings, proc['start_positions'], proc['end_positions'])
            loader = DataLoader(
                ds,
                batch_size=batch_size,
                shuffle=(split=='train'),
                num_workers=num_workers,
                pin_memory=True,
                persistent_workers=(num_workers>0),
                drop_last=(split=='train')
            )
            datasets[f'{split}_dataset'] = ds
            dataloaders[f'{split}_loader'] = loader
    # No test split available
    dataloaders['test_loader'] = None
    datasets['test_dataset'] = None

    return {**datasets, **dataloaders}


In [8]:
# Initialize tokenizer (you already have this)
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

# Test the function
squad_data = get_squad_data(
    tokenizer=tokenizer,
    batch_size=4,      # Small batch size for testing
    num_workers=0,     # Simpler for initial testing
    subset_size=100    # Use a small subset for faster testing
)

# Check the train loader
train_loader = squad_data["train_loader"]
if train_loader is not None:
    for batch in train_loader:
        print("Batch from train_loader:")
        print("  input_ids shape:     ", batch['input_ids'].shape)
        print("  attention_mask shape:", batch['attention_mask'].shape)
        print("  start_positions shape:", batch['start_positions'].shape)
        print("  end_positions shape:  ", batch['end_positions'].shape)
        break
else:
    print("Train loader is None.")

# Check the validation loader
val_loader = squad_data["validation_loader"]
if val_loader is not None:
    for batch in val_loader:
        print("\nBatch from validation_loader:")
        print("  input_ids shape:     ", batch['input_ids'].shape)
        print("  attention_mask shape:", batch['attention_mask'].shape)
        print("  start_positions shape:", batch['start_positions'].shape)
        print("  end_positions shape:  ", batch['end_positions'].shape)
        break
else:
    print("Validation loader is None.")


Batch from train_loader:
  input_ids shape:      torch.Size([4, 384])
  attention_mask shape: torch.Size([4, 384])
  start_positions shape: torch.Size([4])
  end_positions shape:   torch.Size([4])

Batch from validation_loader:
  input_ids shape:      torch.Size([4, 384])
  attention_mask shape: torch.Size([4, 384])
  start_positions shape: torch.Size([4])
  end_positions shape:   torch.Size([4])


  self.answer_starts = torch.tensor(answer_starts, dtype=torch.long)
  self.answer_ends = torch.tensor(answer_ends, dtype=torch.long)


In [None]:
# # Use this as an example to make squad dataset.
# # It wont follow the same logic, so dont copy paste it.

# class GlueDataset(Dataset):
#     def __init__(self, data):
#         self.data = data

#     def __len__(self):
#         return len(self.data)

#     def __getitem__(self, idx):
#         example = self.data[idx]
#         return {
#             "input_ids": example["input_ids"],
#             "attention_mask": example["attention_mask"],
#             "labels": example["label"]
#         }

# def get_glue_data(model_name, tokenizer, task_name, batch_size, num_workers=24):
#     assert task_name in ["mnli", "qqp", "sst2"], f"Unsupported task: {task_name}"
#     dataset = load_dataset("glue", task_name, cache_dir="/dbfs/hf_datasets")
#     print(f"[DATALOADERS] {[key for key in dataset]}")

#     def tokenize_fn(example):
#         if task_name == "mnli":
#             return tokenizer(example["premise"], example["hypothesis"], truncation=True, padding="max_length")
#         if task_name == "qqp":
#             return tokenizer(example["question1"], example["question2"], truncation=True, padding="max_length")
#         if task_name == "sst2":
#             return tokenizer(example["sentence"], truncation=True, padding="max_length")
        
#     dataset = dataset.map(tokenize_fn, batched=True, batch_size=512, num_proc=1)
#     dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])

#     trainset = GlueDataset(dataset["train"])
#     valset = GlueDataset(dataset["validation"])
#     testset = GlueDataset(dataset["test"])

#     loader_args = {
#         "batch_size" : batch_size,
#         "num_workers" : num_workers,
#         "pin_memory" : True,
#         "persistent_workers" : num_workers > 0,
#         "drop_last" : True,
#     }

#     trainloader = DataLoader(dataset["train"], shuffle=True, **loader_args)
#     validationloader = DataLoader(dataset["validation"], **loader_args)
#     testloader = DataLoader(dataset["test"], **loader_args)
  
#     data = {
#         "trainloader": trainloader,
#         "trainset": trainset,
#         "valloader": validationloader,
#         "valset": valset,
#         "testloader": testloader,
#         "testset": testset
#     }
#     return data

In [None]:
# def get_squad_data(tokenizer, task_name, batch_size, num_workers=24):
#     dataset = load_dataset("rajpurkar/squad")
#     print(f"[DATALOADERS] {[key for key in dataset]}")

#     # Tokenize inputs here - dekhlena how to do it
#     # There are 2 sets - train and validation. Their size is too big, add some parameter to take a random subset.
#     # tokenize it properly and return datasets and dataloaders.


### Create Training Script Here:

In [19]:
from transformers import AutoTokenizer, DistilBertForQuestionAnswering, get_linear_schedule_with_warmup
def train(
    model,
    tokenizer,
    train_loader,
    validation_loader=None,
    output_dir="./squad_model",
    num_epochs=3,
    lr=5e-5,
    device=None
):
    
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    print(f"Training on device: {device}")

    total_steps = len(train_loader) * num_epochs
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(0.1 * total_steps),
        num_training_steps=total_steps
    )

    # Epoch loop
    for epoch in range(1, num_epochs + 1):
        model.train()
        running_loss = 0.0

        batchloader = tqdm(train_loader, desc=f"Epoch {epoch} Training", leave=False)
        for step, batch in enumerate(batchloader):
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(
                input_ids=batch['input_ids'],
                attention_mask=batch['attention_mask'],
                start_positions=batch['start_positions'],
                end_positions=batch['end_positions']
            )
            loss = outputs.loss
            running_loss += loss.item() / (step + 1)

            loss.backward()
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            batchloader.set_postfix(Running_Loss=f"{running_loss:.5f}")

        avg_train_loss = running_loss
        print(f"Epoch {epoch} training loss: {avg_train_loss:.4f}")

        # Validation
        if validation_loader:
            model.eval()
            val_loss = 0.0
            for batch in tqdm(validation_loader, desc=f"Epoch {epoch} Validation", leave=False):
                batch = {k: v.to(device) for k, v in batch.items()}
                with torch.no_grad():
                    outputs = model(
                        input_ids=batch['input_ids'],
                        attention_mask=batch['attention_mask'],
                        start_positions=batch['start_positions'],
                        end_positions=batch['end_positions']
                    )
                val_loss += outputs.loss.item()
            avg_val_loss = val_loss / len(validation_loader)
            print(f"Epoch {epoch} validation loss: {avg_val_loss:.4f}")
            model.train()

        # Save checkpoint
        os.makedirs(output_dir, exist_ok=True)
        ckpt_dir = os.path.join(output_dir, f"checkpoint-epoch{epoch}")
        model.save_pretrained(ckpt_dir)
        tokenizer.save_pretrained(ckpt_dir)
        print(f"Saved checkpoint for epoch {epoch} at {ckpt_dir}")

    # Final save
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    print(f"Training complete. Final model saved to {output_dir}")


In [20]:
train(
    model,
    tokenizer,
    train_loader,
    validation_loader=None,
    output_dir="./squad_model",
    num_epochs=3,
    lr=5e-5,
    device=None
)

Training on device: cpu


                                                                                        

Epoch 1 training loss: 20.3418
Saved checkpoint for epoch 1 at ./squad_model\checkpoint-epoch1


                                                                                        

Epoch 2 training loss: 13.6800
Saved checkpoint for epoch 2 at ./squad_model\checkpoint-epoch2


                                                                                        

Epoch 3 training loss: 10.0323
Saved checkpoint for epoch 3 at ./squad_model\checkpoint-epoch3
Training complete. Final model saved to ./squad_model


In [21]:
question = "Who was Jim Henson?"
context = "Jim Henson was a nice puppet creator."
answer = "nice puppet"
start_char = context.index(answer)
end_char = start_char + len(answer)


def tokenize_fn(example):
    return tokenizer(
        example['question'],
        example['context'],
        padding='max_length',
        truncation=True,
        return_offsets_mapping=True,
        return_tensors='pt'
    )

# Prepare example
example = {
    "question": question,
    "context": context,
    "answer": answer
}

inputs = tokenize_fn(example)

# Extract and remove offset_mapping (needed for token-char alignment)
offset_mapping = inputs.pop("offset_mapping")[0]
input_ids = inputs["input_ids"][0]

# Get character-level start and end
start_char = context.index(answer)
end_char = start_char + len(answer)

# Map character positions to token indices
start_token = end_token = None
for idx, (start, end) in enumerate(offset_mapping):
    if start_token is None and start <= start_char < end:
        start_token = idx
    if start < end_char <= end:
        end_token = idx

# Fallback in case alignment fails
if start_token is None or end_token is None:
    raise ValueError("Could not find token span for the answer.")

# Inference
with torch.no_grad():
    outputs = model(**inputs)

answer_start_index = outputs.start_logits.argmax()
answer_end_index = outputs.end_logits.argmax()

# Decode predicted answer
predict_answer_tokens = input_ids[answer_start_index : answer_end_index + 1]
print("Predicted:", tokenizer.decode(predict_answer_tokens, skip_special_tokens=True))

Predicted: jim henson
