From 19a6b90f77953868e2dbf2c095474ba456df5a63 Mon Sep 17 00:00:00 2001
From: skar0
Date: Wed, 7 Feb 2024 09:54:48 +0000
Subject: [PATCH 1/5] Added right-padding to minibatch if less than expected
 size and removed this from final logits/labels

---
 weak_to_strong/train.py | 23 ++++++++++++++++++-----
 1 file changed, 18 insertions(+), 5 deletions(-)

diff --git a/weak_to_strong/train.py b/weak_to_strong/train.py
index 5ba29f5..ac81a12 100644
--- a/weak_to_strong/train.py
+++ b/weak_to_strong/train.py
@@ -131,20 +131,33 @@ def lr_schedule_fn(step):
         for mbatch in to_batch(
             ds, minibatch_size, start=start, end=start + batch_size
         ):
+            batch_len = len(mbatch)
+            # If this is the last batch and it's smaller than
+            # the minibatch_size...
+            if batch_len < minibatch_size:
+                # Calculate number of padding examples needed
+                padding_size = minibatch_size - batch_len
+                # Create clone of the last example for padding
+                padding_example = {key: torch.stack([val[-1]]*padding_size) for key, val in mbatch.items()}
+                # Extend the mbatch with padding examples
+                for key, val in padding_example.items():
+                    mbatch[key] = torch.cat((mbatch[key], val), dim=0)
             input_ids = (
                 torch.nn.utils.rnn.pad_sequence(
-                    [torch.tensor(ids) for ids in mbatch["input_ids"]]  # type: ignore
+                    [torch.tensor(ids) for ids in mbatch["input_ids"]]
                 )
                 .transpose(0, 1)
-                .to(io_device)  # type: ignore
+                .to(io_device)
             )
-            labels = torch.tensor(mbatch["soft_label"]).to(io_device)  # type: ignore
+            labels = torch.tensor(mbatch["soft_label"]).to(io_device)
             logits = model(
                 input_ids, choice_input_ids=mbatch.get("choice_input_ids")
             )
-            all_logits.extend(logits.to(io_device))
-            all_labels.extend(labels)
+            # Ensure only the actual predictions and labels are extended,
+            # not the padded ones
+            all_logits.extend(logits[:batch_len].to(io_device))
+            all_labels.extend(labels[:batch_len])
         all_logits = torch.stack(all_logits)
         all_labels = torch.stack(all_labels)
         all_hard_labels = torch.argmax(all_labels, dim=1)

From 20ab2c0133cefd2f8b2d04ff9c707bbcecb7ea83 Mon Sep 17 00:00:00 2001
From: ojh31
Date: Fri, 9 Feb 2024 21:03:46 +0000
Subject: [PATCH 2/5] Set model parallel for mistral

---
 weak_to_strong/config.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/weak_to_strong/config.py b/weak_to_strong/config.py
index dcbc927..683e2a9 100644
--- a/weak_to_strong/config.py
+++ b/weak_to_strong/config.py
@@ -120,7 +120,7 @@ class ModelConfig:
         ],
         minibatch_size_per_device=1,  # this needs adjusting for GPU/dataset
         gradient_checkpointing=True,
-        model_parallel=False,
+        model_parallel=True,
         custom_kwargs={
            "torch_dtype": torch.bfloat16  # we can only do this because we're using LoRA
            if torch.cuda.is_bf16_supported()
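
PATCH 1/5 pads at the level of the minibatch dict: when the final minibatch is shorter than minibatch_size, its last example is repeated until the minibatch is full, and the extra rows are sliced off the collected logits and labels after the forward pass. Below is a minimal standalone sketch of that idea, assuming every value in the dict is already a stacked tensor; pad_minibatch and the toy shapes are illustrative and not code from the repository.

import torch


def pad_minibatch(mbatch: dict, minibatch_size: int) -> tuple[dict, int]:
    # Right-pad a short minibatch by repeating its last example,
    # assuming every value in the dict is already a stacked tensor.
    batch_len = len(mbatch["input_ids"])
    padding_size = minibatch_size - batch_len
    if padding_size > 0:
        padding = {
            key: torch.stack([val[-1]] * padding_size)
            for key, val in mbatch.items()
        }
        mbatch = {
            key: torch.cat([val, padding[key]], dim=0)
            for key, val in mbatch.items()
        }
    return mbatch, batch_len


# Toy usage: a 3-example minibatch padded up to minibatch_size=4.
mbatch = {
    "input_ids": torch.arange(12).reshape(3, 4),
    "soft_label": torch.rand(3, 2),
}
padded, batch_len = pad_minibatch(mbatch, minibatch_size=4)
assert padded["input_ids"].shape == (4, 4)
# After the forward pass, only the first batch_len rows of the logits
# and labels would be kept, so the padding never reaches the loss.
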
From 313638f3e9e74ab69f8c12c8de3355415bf5e71e Mon Sep 17 00:00:00 2001
From: ojh31
Date: Mon, 12 Feb 2024 09:01:49 +0000
Subject: [PATCH 3/5] train runs now but gets suspicious 100% train accuracy
 despite eval accuracy 50%

---
 weak_to_strong/train.py | 66 ++++++++++++++++++++++++++++++++---------
 1 file changed, 52 insertions(+), 14 deletions(-)

diff --git a/weak_to_strong/train.py b/weak_to_strong/train.py
index ac81a12..502c63e 100644
--- a/weak_to_strong/train.py
+++ b/weak_to_strong/train.py
@@ -19,6 +19,8 @@
 from weak_to_strong.model import TransformerWithHead
 from weak_to_strong.config import ModelConfig
 
+MBATCH_KEYS = ["input_ids", "soft_label", "choice_input_ids"]
+
 
 def save(model: torch.nn.Module, save_path: str, optimizer=None, scheduler=None):
     # Note: If the model is wrapped by DataParallel, we need to unwrap it before saving
@@ -131,17 +133,7 @@ def lr_schedule_fn(step):
         for mbatch in to_batch(
             ds, minibatch_size, start=start, end=start + batch_size
         ):
-            batch_len = len(mbatch)
-            # If this is the last batch and it's smaller than
-            # the minibatch_size...
-            if batch_len < minibatch_size:
-                # Calculate number of padding examples needed
-                padding_size = minibatch_size - batch_len
-                # Create clone of the last example for padding
-                padding_example = {key: torch.stack([val[-1]]*padding_size) for key, val in mbatch.items()}
-                # Extend the mbatch with padding examples
-                for key, val in padding_example.items():
-                    mbatch[key] = torch.cat((mbatch[key], val), dim=0)
+            mbatch_len = len(mbatch)
             input_ids = (
                 torch.nn.utils.rnn.pad_sequence(
                     [torch.tensor(ids) for ids in mbatch["input_ids"]]
@@ -150,14 +142,60 @@ def lr_schedule_fn(step):
                 .to(io_device)
             )
             labels = torch.tensor(mbatch["soft_label"]).to(io_device)
+            assert input_ids.shape[0] == mbatch_len, (
+                f"input_ids.shape[0] ({input_ids.shape[0]}) != mbatch_len "
+                f"({mbatch_len})"
+            )
+            assert labels.shape[0] == mbatch_len, (
+                f"labels.shape[0] ({labels.shape[0]}) != mbatch_len "
+                f"({mbatch_len})"
+            )
+            assert labels.ndim == 1, f"labels.ndim ({labels.ndim}) != 1"
+            assert input_ids.ndim == 2, f"input_ids.ndim ({input_ids.ndim}) != 2"
+            choice_input_ids = mbatch.get("choice_input_ids")
+            padding_size = max(0, minibatch_size - mbatch_len)
+            if padding_size > 0:
+                # If this is the last batch and it's smaller than
+                # the minibatch_size then add padding
+                input_ids = torch.cat(
+                    [input_ids, torch.zeros(
+                        padding_size,
+                        input_ids.shape[1],
+                        device=io_device,
+                        dtype=input_ids.dtype,
+                    )],
+                    dim=0,
+                )
+                labels = torch.cat(
+                    [labels, torch.zeros(
+                        padding_size,
+                        labels.shape[1],
+                        device=io_device,
+                        dtype=labels.dtype,
+                    )],
+                    dim=0,
+                )
+                if choice_input_ids is not None:
+                    choice_input_ids = torch.cat(
+                        [
+                            choice_input_ids,
+                            torch.zeros(
+                                padding_size,
+                                choice_input_ids.shape[1],
+                                device=io_device,
+                                dtype=choice_input_ids.dtype,
+                            ),
+                        ],
+                        dim=0,
+                    )
             logits = model(
-                input_ids, choice_input_ids=mbatch.get("choice_input_ids")
+                input_ids, choice_input_ids=choice_input_ids,
             )
             # Ensure only the actual predictions and labels are extended,
             # not the padded ones
-            all_logits.extend(logits[:batch_len].to(io_device))
-            all_labels.extend(labels[:batch_len])
+            all_logits.extend(logits[:mbatch_len].to(io_device))
+            all_labels.extend(labels[:mbatch_len])
         all_logits = torch.stack(all_logits)
         all_labels = torch.stack(all_labels)
         all_hard_labels = torch.argmax(all_labels, dim=1)
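
PATCH 3/5 instead pads after tensorisation: input_ids, labels and (when present) choice_input_ids are zero-padded along the batch dimension up to minibatch_size, and only the first mbatch_len rows of the logits and labels are accumulated. A condensed sketch of the same pattern, assuming 2-D inputs; pad_to_size and the toy shapes are illustrative, not code from the repository.

import torch


def pad_to_size(x: torch.Tensor, size: int) -> torch.Tensor:
    # Append zero rows along the batch dimension until x has `size` rows.
    padding_size = max(0, size - x.shape[0])
    if padding_size == 0:
        return x
    zeros = torch.zeros(padding_size, *x.shape[1:], dtype=x.dtype, device=x.device)
    return torch.cat([x, zeros], dim=0)


minibatch_size = 4
input_ids = torch.randint(0, 100, (3, 7))  # [batch, pos]
labels = torch.rand(3, 2)                  # [batch, num_classes]
mbatch_len = input_ids.shape[0]

padded_ids = pad_to_size(input_ids, minibatch_size)
padded_labels = pad_to_size(labels, minibatch_size)
assert padded_ids.shape == (4, 7) and padded_labels.shape == (4, 2)

# logits = model(padded_ids, ...)  # the model always sees a fixed-size batch
# all_logits.extend(logits[:mbatch_len]); all_labels.extend(padded_labels[:mbatch_len])
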
From 64218682cd7c4f9e4187454de0969c9dad0e0e6e Mon Sep 17 00:00:00 2001
From: ojh31
Date: Mon, 12 Feb 2024 10:16:32 +0000
Subject: [PATCH 4/5] Fixed definition of mbatch len and logged std error of
 train accuracy

---
 weak_to_strong/train.py | 31 ++++++++++++-------------------
 1 file changed, 12 insertions(+), 19 deletions(-)

diff --git a/weak_to_strong/train.py b/weak_to_strong/train.py
index 502c63e..69f4f22 100644
--- a/weak_to_strong/train.py
+++ b/weak_to_strong/train.py
@@ -98,6 +98,7 @@ def lr_schedule_fn(step):
     step = 0
     losses = []
     accuracies = []
+    accuracy_errors = []
     aurocs = []
     eval_acc_dict = {}
 
@@ -133,25 +134,15 @@ def lr_schedule_fn(step):
         for mbatch in to_batch(
             ds, minibatch_size, start=start, end=start + batch_size
         ):
-            mbatch_len = len(mbatch)
+            mbatch_len = len(mbatch["input_ids"])
             input_ids = (
                 torch.nn.utils.rnn.pad_sequence(
                     [torch.tensor(ids) for ids in mbatch["input_ids"]]
                 )
                 .transpose(0, 1)
                 .to(io_device)
-            )
-            labels = torch.tensor(mbatch["soft_label"]).to(io_device)
-            assert input_ids.shape[0] == mbatch_len, (
-                f"input_ids.shape[0] ({input_ids.shape[0]}) != mbatch_len "
-                f"({mbatch_len})"
-            )
-            assert labels.shape[0] == mbatch_len, (
-                f"labels.shape[0] ({labels.shape[0]}) != mbatch_len "
-                f"({mbatch_len})"
-            )
-            assert labels.ndim == 1, f"labels.ndim ({labels.ndim}) != 1"
-            assert input_ids.ndim == 2, f"input_ids.ndim ({input_ids.ndim}) != 2"
+            ) # [batch, pos]
+            labels = torch.tensor(mbatch["soft_label"]).to(io_device) # [batch, num_classes]
             choice_input_ids = mbatch.get("choice_input_ids")
             padding_size = max(0, minibatch_size - mbatch_len)
             if padding_size > 0:
@@ -206,12 +197,13 @@ def lr_schedule_fn(step):
         loss_tot += loss.item()
         loss.backward()
         losses.append(loss_tot)
-        accuracies.append(
-            torch.mean(
-                (torch.argmax(all_logits, dim=1) == all_hard_labels).to(
-                    torch.float32
-                )
-            ).item()
+        is_correct = (
+            torch.argmax(all_logits, dim=1) == all_hard_labels
+        ).to(torch.float32)
+        accuracies.append(torch.mean(is_correct).item())
+        accuracy_errors.append(
+            np.std(is_correct.cpu().numpy()) /
+            np.sqrt(len(all_hard_labels))
         )
 
         try:
@@ -226,6 +218,7 @@ def lr_schedule_fn(step):
                 "progress": step / nsteps,
                 "loss": loss_tot,
                 "train_accuracy": accuracies[-1],
+                "train_accuracy_error": accuracy_errors[-1],
                 "train_auroc": aurocs[-1],
                 "lr": lr_scheduler.get_last_lr()[0],
             }
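
PATCH 4/5 also logs the standard error of the per-step train accuracy: the standard deviation of the 0/1 correctness indicators divided by the square root of the number of examples (np.std defaults to the population standard deviation, matching the patch). A small illustrative sketch of that computation with made-up predictions:

import numpy as np
import torch

# Made-up hard labels and predictions for one training step.
hard_labels = torch.tensor([1, 1, 1, 1, 0])
predictions = torch.tensor([1, 0, 1, 1, 0])

# 0/1 correctness of each prediction in the batch.
is_correct = (predictions == hard_labels).to(torch.float32)

accuracy = torch.mean(is_correct).item()  # 0.8
# Standard error of the mean: population std divided by sqrt(n),
# mirroring np.std(...) / np.sqrt(len(...)) in the patch.
accuracy_error = np.std(is_correct.numpy()) / np.sqrt(len(is_correct))
print(f"train_accuracy = {accuracy:.2f} +/- {accuracy_error:.2f}")
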
From 825da72b5ed1648be8b71a1f0b2d4c1cc03291af Mon Sep 17 00:00:00 2001
From: skar0
Date: Mon, 12 Feb 2024 18:29:22 +0000
Subject: [PATCH 5/5] Checked for empty minibatch and dropped std error of
 accuracy

---
 weak_to_strong/train.py | 16 ++++++----------
 1 file changed, 6 insertions(+), 10 deletions(-)

diff --git a/weak_to_strong/train.py b/weak_to_strong/train.py
index 69f4f22..d363a6c 100644
--- a/weak_to_strong/train.py
+++ b/weak_to_strong/train.py
@@ -19,8 +19,6 @@
 from weak_to_strong.model import TransformerWithHead
 from weak_to_strong.config import ModelConfig
 
-MBATCH_KEYS = ["input_ids", "soft_label", "choice_input_ids"]
-
 
 def save(model: torch.nn.Module, save_path: str, optimizer=None, scheduler=None):
     # Note: If the model is wrapped by DataParallel, we need to unwrap it before saving
@@ -98,7 +96,6 @@ def lr_schedule_fn(step):
     step = 0
     losses = []
     accuracies = []
-    accuracy_errors = []
     aurocs = []
     eval_acc_dict = {}
 
@@ -135,14 +132,16 @@ def lr_schedule_fn(step):
             ds, minibatch_size, start=start, end=start + batch_size
         ):
             mbatch_len = len(mbatch["input_ids"])
+            if mbatch_len == 0:
+                continue
             input_ids = (
                 torch.nn.utils.rnn.pad_sequence(
                     [torch.tensor(ids) for ids in mbatch["input_ids"]]
                 )
                 .transpose(0, 1)
                 .to(io_device)
-            ) # [batch, pos]
-            labels = torch.tensor(mbatch["soft_label"]).to(io_device) # [batch, num_classes]
+            )  # [batch, pos]
+            labels = torch.tensor(mbatch["soft_label"]).to(io_device)  # [batch, num_classes]
             choice_input_ids = mbatch.get("choice_input_ids")
             padding_size = max(0, minibatch_size - mbatch_len)
             if padding_size > 0:
@@ -187,6 +186,8 @@ def lr_schedule_fn(step):
             # not the padded ones
             all_logits.extend(logits[:mbatch_len].to(io_device))
             all_labels.extend(labels[:mbatch_len])
+        if len(all_logits) == 0:
+            continue
         all_logits = torch.stack(all_logits)
         all_labels = torch.stack(all_labels)
         all_hard_labels = torch.argmax(all_labels, dim=1)
@@ -201,10 +202,6 @@ def lr_schedule_fn(step):
             torch.argmax(all_logits, dim=1) == all_hard_labels
         ).to(torch.float32)
         accuracies.append(torch.mean(is_correct).item())
-        accuracy_errors.append(
-            np.std(is_correct.cpu().numpy()) /
-            np.sqrt(len(all_hard_labels))
-        )
 
         try:
             auroc = roc_auc_score(all_hard_labels.cpu(), all_logprobs.cpu())
@@ -218,7 +215,6 @@
                 "progress": step / nsteps,
                 "loss": loss_tot,
                 "train_accuracy": accuracies[-1],
-                "train_accuracy_error": accuracy_errors[-1],
                 "train_auroc": aurocs[-1],
                 "lr": lr_scheduler.get_last_lr()[0],
             }