In [0]:
%load_ext autoreload
%autoreload 2
# Enables autoreload; learn more at https://docs.databricks.com/en/files/workspace-modules.html#autoreload-for-python-modules
# To disable autoreload; run %autoreload 0

In [0]:
%pip install torchinfo
%restart_python

In [0]:
import os

# Hugging Face `datasets` reads HF_DATASETS_CACHE only once, when `datasets.config` is first
# imported, so these variables must be set before ANY import below that can pull in
# `datasets`, `transformers` or `tokenizers` (the project modules presumably do too).
# Previously they were assigned after `from datasets import load_dataset`, so the cache
# override was silently ignored.
os.environ["HF_DATASETS_CACHE"] = "/dbfs/hf_datasets"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

from bacp import BaCPLearner, BaCPTrainer, BaCPTrainingArgumentsLLM
from models import EncoderProjectionNetwork, ClassificationNetwork
from unstructured_pruning import MagnitudePrune, MovementPrune, LocalMagnitudePrune, LocalMovementPrune, WandaPrune, PRUNER_DICT, check_model_sparsity
from LLM_trainer import LLMTrainer, LLMTrainingArguments
from dataset_utils import get_glue_data, get_squad_data
from logger import Logger

import torch
import torch.nn as nn
import torch.optim as optim
from datasets import load_dataset
from transformers import AutoTokenizer

from tqdm import tqdm
from torchinfo import summary

from datasets.utils.logging import disable_progress_bar
disable_progress_bar()

# NOTE(review): star imports hide where names such as `get_device` come from; prefer
# explicit imports once the needed names are known.
from utils import *
from constants import *

# Compute device chosen by the project helper; later cells should pass this to `train`.
device = get_device()
print(f"{device = }")
BATCH_SIZE_DISTILBERT = 64  # NOTE(review): not referenced by any visible cell
NUM_WORKERS = 24  # worker count handed to the SQuAD data loaders


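### DataLoader

The SQuAD train/validation loaders are built by `get_squad_data` (from `dataset_utils`) in the last cell, right before training.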

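### Training script

`train` fine-tunes a question-answering model with AdamW and a linear warmup/decay schedule. It reports validation loss and start/end token-position accuracy after every epoch, then saves the model and tokenizer to `output_dir`.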

In [0]:
from transformers import AutoTokenizer, DistilBertForQuestionAnswering, get_linear_schedule_with_warmup
import torch
from tqdm import tqdm

def _move_to_device(batch, device):
    """Return a new dict with every value of ``batch`` moved to ``device`` via ``.to``."""
    return {key: value.to(device) for key, value in batch.items()}


def _evaluate(model, loader, device):
    """Run one validation pass and return loss and span-position accuracy metrics.

    Args:
        model: QA model whose forward pass, called with the batch keys as kwargs,
            returns an object with ``loss``, ``start_logits`` and ``end_logits``.
        loader: iterable of batch dicts with ``input_ids``, ``start_positions``
            and ``end_positions`` tensors.
        device: device the batch tensors are moved to.

    Returns:
        dict with ``loss`` (mean per-batch loss), ``start_acc``, ``end_acc`` and
        ``exact_match_acc`` (percentages), and ``total`` (number of samples), or
        ``None`` when the loader yields no samples.
    """
    model.eval()
    batch_losses = []
    correct_start = correct_end = exact_matches = total = 0

    with torch.no_grad():
        for batch in tqdm(loader, desc="Validation"):
            batch = _move_to_device(batch, device)
            outputs = model(**batch)
            batch_losses.append(outputs.loss.item())

            start_hit = outputs.start_logits.argmax(dim=1) == batch['start_positions']
            end_hit = outputs.end_logits.argmax(dim=1) == batch['end_positions']
            correct_start += start_hit.sum().item()
            correct_end += end_hit.sum().item()
            # "Exact match" = both predicted token positions equal the labels; this is a
            # position-level metric, not the text-level SQuAD EM score.
            exact_matches += (start_hit & end_hit).sum().item()
            total += batch['input_ids'].size(0)

    if total == 0:
        # Avoid ZeroDivisionError on an empty loader.
        return None
    return {
        "loss": sum(batch_losses) / len(batch_losses),
        "start_acc": (correct_start / total) * 100,
        "end_acc": (correct_end / total) * 100,
        "exact_match_acc": (exact_matches / total) * 100,
        "total": total,
    }


def train(
    model,
    tokenizer,
    train_loader,
    validation_loader=None,
    output_dir="./squad_model",
    num_epochs=3,
    lr=5e-5,
    device=None,
    warmup_ratio=0.1,
):
    """Fine-tune a question-answering model, optionally validating after each epoch.

    Uses AdamW with a linear warmup/decay schedule. Each epoch it prints the mean
    training loss and, when ``validation_loader`` is given, the validation loss and
    start/end/both token-position accuracies. At the end it saves the model and
    tokenizer to ``output_dir``.

    Args:
        model: model whose forward, called with each batch dict as kwargs, returns
            an object with ``loss`` (and ``start_logits``/``end_logits`` for validation);
            must provide ``save_pretrained``.
        tokenizer: object with ``save_pretrained``; saved next to the model.
        train_loader: sized iterable of batch dicts (``len`` is used for the schedule).
        validation_loader: optional iterable of batch dicts with ``input_ids``,
            ``start_positions`` and ``end_positions``; need not support ``len``.
        output_dir: directory passed to ``save_pretrained``.
        num_epochs: number of passes over ``train_loader``.
        lr: AdamW learning rate.
        device: target device; defaults to ``"cuda"`` if available, else ``"cpu"``.
        warmup_ratio: fraction of total steps used for linear LR warmup.

    Returns:
        dict with ``"train_loss"`` (mean training loss per epoch, as floats) and
        ``"val_metrics"`` (one metrics dict, or ``None`` for an empty loader, per epoch;
        empty when no validation loader is given).
    """
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    print(f"Training on device: {device}")

    total_steps = len(train_loader) * num_epochs
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(warmup_ratio * total_steps),
        num_training_steps=total_steps,
    )

    history = {"train_loss": [], "val_metrics": []}
    for epoch in range(1, num_epochs + 1):
        # Training phase
        model.train()
        batch_losses = []
        progress = tqdm(train_loader, desc=f"Epoch {epoch} Training")
        for batch in progress:
            optimizer.zero_grad()
            outputs = model(**_move_to_device(batch, device))
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            scheduler.step()

            loss_value = loss.item()  # single host sync per step
            batch_losses.append(loss_value)
            progress.set_postfix(Running_Loss=f"{loss_value:.4f}")

        avg_loss = sum(batch_losses) / len(batch_losses) if batch_losses else float("nan")
        history["train_loss"].append(avg_loss)
        print(f"Epoch {epoch} training loss: {avg_loss:.4f}")

        # Validation phase. Use `is not None`, not truthiness: bool(DataLoader) calls
        # len(), which raises TypeError for loaders over an IterableDataset.
        if validation_loader is not None:
            metrics = _evaluate(model, validation_loader, device)
            history["val_metrics"].append(metrics)
            if metrics is None:
                print("Validation loader yielded no samples; skipping metrics.")
            else:
                print(f"Validation loss: {metrics['loss']:.4f}")
                print(f"Start accuracy: {metrics['start_acc']:.2f}")
                print(f"End accuracy: {metrics['end_acc']:.2f}")
                print(f"Exact match accuracy: {metrics['exact_match_acc']:.2f}")
                print(f"Total samples: {metrics['total']}")

    # Save model
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    print(f"Model saved to {output_dir}")
    return history

In [0]:
from transformers import AutoTokenizer, DistilBertForQuestionAnswering

# Fine-tune `distilbert-base-uncased` for question answering with the `train` function above.
model_name = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = DistilBertForQuestionAnswering.from_pretrained(model_name)
subset_ratio = 1.0  # NOTE(review): presumably the fraction of SQuAD loaded; lower for quick runs
max_length = 512  # NOTE(review): positional arg of get_squad_data, presumably the max sequence length — confirm

data = get_squad_data(tokenizer, max_length, subset_ratio=subset_ratio, num_workers=NUM_WORKERS)
trainloader = data['trainloader']
valloader = data['valloader']

# Pass the device resolved by `get_device()` in the setup cell, so training runs on the
# device printed there instead of `train` re-deriving its own cuda/cpu choice.
# The result is assigned so the notebook does not auto-display it.
history = train(
    model,
    tokenizer,
    trainloader,
    validation_loader=valloader,
    output_dir="./squad_model",
    num_epochs=3,
    lr=5e-5,
    device=device,
)