diff --git a/gptqmodel/utils/data.py b/gptqmodel/utils/data.py
index b5de25f5e..7e30f0ae4 100644
--- a/gptqmodel/utils/data.py
+++ b/gptqmodel/utils/data.py
@@ -150,7 +150,7 @@ def collate_data(batch: List[Dict[str, List[List[int]]]], pad_token_id: int) ->
     Each element of `batch` looks like:
         {
           "input_ids": List[List[int]],      # rows
-          "attention_mask": List[List[int]], # rows
+          "attention_mask": List[List[int]], # rows (0/1 ints, cast to bool here)
         }
     """
     # Flatten rows across all items in the outer batch
@@ -166,9 +166,9 @@ def collate_data(batch: List[Dict[str, List[List[int]]]], pad_token_id: int) ->
     for r in range(len(ids_list)):
         ids = torch.as_tensor(ids_list[r], dtype=torch.long)
-        msk = torch.as_tensor(msk_list[r], dtype=torch.long)
+        # make mask boolean immediately
+        msk = torch.as_tensor(msk_list[r], dtype=torch.bool)
 
-        # ensure pre-pad lengths match within the row
         if ids.numel() != msk.numel():
             raise ValueError("Row has mismatched lengths between input_ids and attention_mask")
@@ -179,24 +179,30 @@ def collate_data(batch: List[Dict[str, List[List[int]]]], pad_token_id: int) ->
     max_len = max(t.numel() for t in rows_ids) if rows_ids else 0
 
     # Right-pad each row to global max_len
-    def right_pad(row: torch.Tensor, pad_value: int) -> torch.Tensor:
+    def right_pad(row: torch.Tensor, pad_value, dtype=None) -> torch.Tensor:
         pad_len = max_len - row.numel()
         if pad_len <= 0:
             return row
-        return torch.cat([row, torch.full((pad_len,), pad_value, dtype=row.dtype, device=row.device)], dim=0)
-
-    padded_ids = [right_pad(t, pad_token_id) for t in rows_ids]
-    padded_msk = [right_pad(t, 0) for t in rows_mask]
+        return torch.cat(
+            [
+                row,
+                torch.full((pad_len,), pad_value, dtype=dtype or row.dtype, device=row.device),
+            ],
+            dim=0,
+        )
+
+    padded_ids = [right_pad(t, pad_token_id, dtype=torch.long) for t in rows_ids]
+    # pad masks with False, not 0
+    padded_msk = [right_pad(t, False, dtype=torch.bool) for t in rows_mask]
 
     # Stack into [total_rows_in_batch, max_len]
     input_ids = torch.stack(padded_ids, dim=0) if padded_ids else torch.empty((0, 0), dtype=torch.long)
-    attention_mask = torch.stack(padded_msk, dim=0) if padded_msk else torch.empty((0, 0), dtype=torch.long)
+    attention_mask = torch.stack(padded_msk, dim=0) if padded_msk else torch.empty((0, 0), dtype=torch.bool)
 
-    out = {
+    return {
         "input_ids": input_ids,
         "attention_mask": attention_mask,
     }
-    return out
 
 
 def get_dataloader(
diff --git a/tests/models/test_llama3_2.py b/tests/models/test_llama3_2.py
index c5c0666a2..f8175aebc 100644
--- a/tests/models/test_llama3_2.py
+++ b/tests/models/test_llama3_2.py
@@ -21,7 +21,7 @@ class TestLlama3_2(ModelTest):
     ACT_GROUP_AWARE = True
     DESC_ACT = False
     DATASET_SIZE = 1024
-    DATASET_SORT = "asc"
+    DATASET_SORT = "desc"
     QUANT_BATCH_SIZE = 4
     # b1 = 0.315, b4 = 0.3106, b8 = 0.3148, b32 = 0.3148, b16 = 0.3234
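
For reference, a minimal sketch (not part of the diff) of how the updated collate_data is expected to behave after this change: it assumes the function is importable from gptqmodel.utils.data, and the sample batch below is hypothetical.

    import torch
    from gptqmodel.utils.data import collate_data

    # two batch items, one row each, with the 0/1 integer masks produced upstream
    batch = [
        {"input_ids": [[1, 2, 3]], "attention_mask": [[1, 1, 1]]},
        {"input_ids": [[4, 5]],    "attention_mask": [[1, 1]]},
    ]

    out = collate_data(batch, pad_token_id=0)
    # out["input_ids"]      -> tensor([[1, 2, 3], [4, 5, 0]])   torch.long, padded with pad_token_id
    # out["attention_mask"] -> tensor([[True, True, True],
    #                                  [True, True, False]])    torch.bool, padded with False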