Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 20 additions & 18 deletions gptqmodel/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,29 +735,31 @@ def _maybe_resolve_length(value, source_name):
collate_data(sorted_dataset[start: start + batch_size], self.tokenizer.pad_token_id)
for start in range(0, len(sorted_dataset), batch_size)
]
else:
new_calibration_dataset_batched = [
{"input_ids": torch.tensor(block["input_ids"], dtype=torch.long)}
for block in sorted_dataset
]

# total tokens counters
total_padded = 0
total_non_padded = 0
# total tokens counters
total_padded = 0
total_non_padded = 0

for batch in new_calibration_dataset_batched:
# attention_mask is shape [batch_size, seq_len]
mask = batch["attention_mask"]
for batch in new_calibration_dataset_batched:
# attention_mask is shape [batch_size, seq_len]
mask = batch["attention_mask"]

# count where mask == 0 (padded tokens)
total_padded += (mask == 0).sum().item()
# count where mask == 0 (padded tokens)
total_padded += (mask == 0).sum().item()

# count where mask == 1 (non-padded tokens)
total_non_padded += (mask == 1).sum().item()
# count where mask == 1 (non-padded tokens)
total_non_padded += (mask == 1).sum().item()

log.info(f"Calibration: Total padded tokens: {total_padded}")
log.info(f"Calibration: Total non-padded tokens: {total_non_padded}")
log.info(f"Calibration: Total tokens: {total_non_padded + total_padded}")
log.info(f"Calibration: Total padded tokens: {total_padded}")
log.info(f"Calibration: Total non-padded tokens: {total_non_padded}")
log.info(f"Calibration: Total tokens: {total_non_padded + total_padded}")
else:
new_calibration_dataset_batched = [
{
"input_ids": torch.tensor(block["input_ids"], dtype=torch.long),
}
for block in sorted_dataset
]

return new_calibration_dataset_batched

Expand Down