# In [2]:
import torch

# Number of samples that should contribute to each optimizer update.
effective_batch_size = 64
# Number of samples processed per forward/backward pass
# (presumably the DataLoader batch size -- confirm against the loader setup).
total_batch_size = 16
# Micro-batches to accumulate before each optimizer update. Clamped to at
# least 1: if the per-pass batch is larger than the effective batch, the
# floor division yields 0, which is unusable as an accumulation interval.
grad_accum_steps = max(1, effective_batch_size // total_batch_size)

# Create training loop function
def train_step(model: torch.nn.Module,
               dataloader: torch.utils.data.DataLoader,
               loss_fn: torch.nn.Module,
               optimizer: torch.optim.Optimizer,
               device,
               accum_steps=None):
    """Run one pass over ``dataloader`` with gradient accumulation.

    Each batch is run through ``model`` under bfloat16 autocast, its loss is
    scaled by ``1 / accum_steps`` and back-propagated. After every
    ``accum_steps`` batches the gradients are clipped to a total norm of 1.0
    and the optimizer is stepped. Batches left over at the end of the loader
    (when the batch count is not a multiple of ``accum_steps``) are applied
    in one final update instead of being dropped.

    Args:
        model: Called as ``model(input_ids, attention_mask)``. Its output is
            arg-maxed over dim 1 and compared with ``labels``, so it is
            assumed to be ``(batch, num_classes)`` scores with ``labels``
            holding class indices -- confirm against the model in use.
        dataloader: Iterable of batches. A batch is either a 3-item
            list/tuple ``(input_ids, attention_mask, labels)`` or a mapping
            with the keys ``"input_ids"``, ``"attention_mask"``, ``"labels"``.
        loss_fn: Called as ``loss_fn(y_pred, labels)``; must return a scalar.
        optimizer: Optimizer that owns ``model``'s parameters.
        device: Target device; a string (``"cuda"``, ``"cuda:0"``, ``"cpu"``)
            or a ``torch.device``.
        accum_steps: Batches to accumulate per optimizer update. ``None``
            (the default) uses the module-level ``grad_accum_steps``.

    Returns:
        ``(train_loss, train_acc)``: the mean un-scaled loss per batch as a
        0-dim tensor, and the mean per-batch accuracy as a Python float.
        Both are zero when the loader yields no batches.

    Raises:
        ValueError: If ``accum_steps`` is smaller than 1.
    """
    if accum_steps is None:
        # Existing callers rely on the module-level setting.
        accum_steps = grad_accum_steps
    if accum_steps < 1:
        raise ValueError(f"accum_steps must be >= 1, got {accum_steps}")

    # torch.autocast wants the bare device type ("cuda", "cpu"), not a
    # torch.device object or an indexed string such as "cuda:0".
    device_type = torch.device(device).type

    def apply_update():
        # Clip, step, then clear gradients for the next accumulation group.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad()

    model.train()
    optimizer.zero_grad()

    # Metrics are accumulated as tensors on the device and converted once at
    # the end, so the loop never blocks on a device-to-host transfer.
    loss_sum = torch.zeros((), device=device)
    acc_sum = torch.zeros((), device=device)
    num_batches = 0
    pending = 0  # micro-batches back-propagated but not yet applied

    for batch in dataloader:
        # Support both (input_ids, attention_mask, labels) and dict batches.
        if isinstance(batch, (list, tuple)) and len(batch) == 3:
            input_ids, attention_mask, labels = batch
        else:
            input_ids = batch["input_ids"]
            attention_mask = batch["attention_mask"]
            labels = batch["labels"]

        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        labels = labels.to(device)

        # Forward pass in mixed precision.
        with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
            y_pred = model(input_ids, attention_mask)
            loss = loss_fn(y_pred, labels)

        # Only the back-propagated value is scaled; the reported loss below
        # uses the un-scaled value so it is a true per-batch average.
        (loss / accum_steps).backward()

        with torch.no_grad():
            loss_sum += loss.detach().float()
            predicted = torch.argmax(y_pred, dim=1)
            acc_sum += (predicted == labels).float().mean()

        num_batches += 1
        pending += 1
        if pending == accum_steps:
            apply_update()
            pending = 0

    if pending:
        # Trailing partial group: its gradients were still scaled by
        # 1 / accum_steps, so this last update is proportionally smaller.
        apply_update()

    if num_batches == 0:
        return loss_sum, 0.0

    train_loss = loss_sum / num_batches
    train_acc = (acc_sum / num_batches).item()
    return train_loss, train_acc