Bugfix: last minibatch padding #8

Open. Wants to merge 5 commits into main.
2 changes: 1 addition & 1 deletion weak_to_strong/config.py
@@ -120,7 +120,7 @@ class ModelConfig:
         ],
         minibatch_size_per_device=1,  # this needs adjusting for GPU/dataset
         gradient_checkpointing=True,
-        model_parallel=False,
+        model_parallel=True,
         custom_kwargs={
             "torch_dtype": torch.bfloat16  # we can only do this because we're using LoRA
             if torch.cuda.is_bf16_supported()
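Note on the config change above: it flips model_parallel on for this model entry. The repo's loading code is not part of this diff, so the snippet below is only a rough sketch of how a model_parallel flag is commonly consumed with Hugging Face transformers (sharding layers across visible GPUs via device_map); the load_model helper and its signature are hypothetical, not the repo's API.

# Illustrative only -- not the repo's actual loading code.
# Sketch of how a ModelConfig.model_parallel flag is often wired up.
import torch
from transformers import AutoModelForCausalLM

def load_model(name: str, model_parallel: bool):  # hypothetical helper
    kwargs = {
        "torch_dtype": torch.bfloat16
        if torch.cuda.is_bf16_supported()
        else torch.float32,
    }
    if model_parallel:
        # Shard the model across all visible GPUs instead of one device.
        kwargs["device_map"] = "auto"
    model = AutoModelForCausalLM.from_pretrained(name, **kwargs)
    if not model_parallel:
        model = model.cuda()
    return model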
68 changes: 54 additions & 14 deletions weak_to_strong/train.py
@@ -131,20 +131,63 @@ def lr_schedule_fn(step):
         for mbatch in to_batch(
             ds, minibatch_size, start=start, end=start + batch_size
         ):
+            mbatch_len = len(mbatch["input_ids"])
+            if mbatch_len == 0:
+                continue
             input_ids = (
                 torch.nn.utils.rnn.pad_sequence(
-                    [torch.tensor(ids) for ids in mbatch["input_ids"]]  # type: ignore
+                    [torch.tensor(ids) for ids in mbatch["input_ids"]]
                 )
                 .transpose(0, 1)
-                .to(io_device)  # type: ignore
-            )
-            labels = torch.tensor(mbatch["soft_label"]).to(io_device)  # type: ignore
+                .to(io_device)
+            )  # [batch, pos]
+            labels = torch.tensor(mbatch["soft_label"]).to(io_device)  # [batch, num_classes]
+            choice_input_ids = mbatch.get("choice_input_ids")
+            padding_size = max(0, minibatch_size - mbatch_len)
+            if padding_size > 0:
+                # If this is the last batch and it's smaller than
+                # the minibatch_size then add padding
+                input_ids = torch.cat(
+                    [input_ids, torch.zeros(
+                        padding_size,
+                        input_ids.shape[1],
+                        device=io_device,
+                        dtype=input_ids.dtype,
+                    )],
+                    dim=0,
+                )
+                labels = torch.cat(
+                    [labels, torch.zeros(
+                        padding_size,
+                        labels.shape[1],
+                        device=io_device,
+                        dtype=labels.dtype,
+                    )],
+                    dim=0,
+                )
+                if choice_input_ids is not None:
+                    choice_input_ids = torch.cat(
+                        [
+                            choice_input_ids,
+                            torch.zeros(
+                                padding_size,
+                                choice_input_ids.shape[1],
+                                device=io_device,
+                                dtype=choice_input_ids.dtype,
+                            ),
+                        ],
+                        dim=0,
+                    )
             logits = model(
-                input_ids, choice_input_ids=mbatch.get("choice_input_ids")
+                input_ids, choice_input_ids=choice_input_ids,
             )

-            all_logits.extend(logits.to(io_device))
-            all_labels.extend(labels)
+            # Ensure only the actual predictions and labels are extended,
+            # not the padded ones
+            all_logits.extend(logits[:mbatch_len].to(io_device))
+            all_labels.extend(labels[:mbatch_len])
+        if len(all_logits) == 0:
+            continue
         all_logits = torch.stack(all_logits)
Collaborator comment on the hunk above: One edge case to consider is what happens if mbatch_len is 0.
         all_labels = torch.stack(all_labels)
         all_hard_labels = torch.argmax(all_labels, dim=1)
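Taken out of the training loop, the fix above boils down to: pad the final, short minibatch up to the fixed minibatch size so the forward pass always sees the same batch shape, then slice the outputs back down so padded rows never reach the loss or the metrics; the new mbatch_len == 0 and len(all_logits) == 0 guards cover the empty-minibatch edge case the reviewer raises. A minimal self-contained sketch of that pattern (tensor names and sizes here are illustrative, not the repo's):

import torch

minibatch_size = 4
# Suppose the last minibatch only has 3 real examples.
input_ids = torch.randint(0, 100, (3, 16))  # [batch, pos]
labels = torch.rand(3, 2)                   # [batch, num_classes]
mbatch_len = input_ids.shape[0]

padding_size = max(0, minibatch_size - mbatch_len)
if padding_size > 0:
    # Zero-pad up to the fixed minibatch size so shapes stay constant.
    input_ids = torch.cat(
        [input_ids, torch.zeros(padding_size, input_ids.shape[1], dtype=input_ids.dtype)],
        dim=0,
    )
    labels = torch.cat(
        [labels, torch.zeros(padding_size, labels.shape[1], dtype=labels.dtype)],
        dim=0,
    )

logits = torch.randn(minibatch_size, 2)  # stand-in for model(input_ids, ...)
# Keep only the real rows; the padded rows must not affect loss or metrics.
real_logits = logits[:mbatch_len]
real_labels = labels[:mbatch_len]
assert real_logits.shape[0] == real_labels.shape[0] == mbatch_len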
@@ -155,13 +198,10 @@ def lr_schedule_fn(step):
         loss_tot += loss.item()
         loss.backward()
         losses.append(loss_tot)
-        accuracies.append(
-            torch.mean(
-                (torch.argmax(all_logits, dim=1) == all_hard_labels).to(
-                    torch.float32
-                )
-            ).item()
-        )
+        is_correct = (
+            torch.argmax(all_logits, dim=1) == all_hard_labels
+        ).to(torch.float32)
+        accuracies.append(torch.mean(is_correct).item())

         try:
             auroc = roc_auc_score(all_hard_labels.cpu(), all_logprobs.cpu())
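For reference, a small sketch of how the batch accuracy and AUROC fall out of the stacked tensors in this hunk. It assumes binary hard labels and that all_logprobs holds one positive-class score per example; that definition is not shown in the diff, so the log-softmax line below is an assumption, not the repo's code.

import torch
from sklearn.metrics import roc_auc_score

# Toy stand-ins for the stacked tensors in the loop above.
all_logits = torch.tensor([[2.0, 0.5], [0.1, 1.2], [1.5, 1.4]])
all_labels = torch.tensor([[0.9, 0.1], [0.2, 0.8], [0.3, 0.7]])  # soft labels
all_hard_labels = torch.argmax(all_labels, dim=1)

# Accuracy: fraction of examples whose argmax prediction matches the hard label.
is_correct = (torch.argmax(all_logits, dim=1) == all_hard_labels).to(torch.float32)
accuracy = torch.mean(is_correct).item()

# Assumed positive-class log-probabilities; AUROC only uses their ranking.
all_logprobs = torch.log_softmax(all_logits, dim=1)[:, 1]
auroc = roc_auc_score(all_hard_labels.cpu(), all_logprobs.cpu())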