Upgrade PTL to 1.0.2 #1278
Conversation
Let's hold off on this until 1.0.0 is available.
Force-pushed from 9e53ccd to 09d8f23.
def compute_topk_accuracy(correct_counts, total_counts):
This function is used in asr/models/classification_models.py and asr/models/label_models.py
We might need to create a functional metric for this case.
We should not tie such metrics to PTL itself, but keep them general and simply use them in PTL metric wrappers. Otherwise it will become harder to use them outside of PTL contexts.
The PTL Metrics API is independent of PTL. Take a look at this example from their docs:

```python
from pytorch_lightning import metrics

train_accuracy = metrics.Accuracy()
valid_accuracy = metrics.Accuracy(compute_on_step=False)

for epoch in range(epochs):
    for x, y in train_data:
        y_hat = model(x)
        # training step accuracy
        batch_acc = train_accuracy(y_hat, y)

    for x, y in valid_data:
        y_hat = model(x)
        valid_accuracy(y_hat, y)

# total accuracy over all training batches
total_train_accuracy = train_accuracy.compute()

# total accuracy over all validation batches
total_valid_accuracy = valid_accuracy.compute()
```
Yes, I understand that, but why do we want to turn a basic utility function into something that wraps PTL in any form? Please revert this one and simply call it from a PTL metric class if needed.
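For instance, the utility could stay a plain, PTL-free function that a PTL `Metric` merely delegates to. A minimal sketch of that separation (the function body and list-based inputs here are assumptions for illustration, not NeMo's actual implementation):

```python
# Hypothetical sketch: keep the metric math PTL-free, wrap it separately.
# The body of compute_topk_accuracy below is assumed, not copied from NeMo.

def compute_topk_accuracy(correct_counts, total_counts):
    """Plain utility: per-k accuracy from accumulated correct/total counts."""
    return [c / t for c, t in zip(correct_counts, total_counts)]


# A PTL Metric wrapper would then only accumulate state and delegate in
# compute(), keeping the utility usable outside PTL, e.g.:
#
#   class TopKAccuracy(pytorch_lightning.metrics.Metric):
#       def compute(self):
#           return compute_topk_accuracy(self.correct_counts, self.total_counts)
```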
```python
checkpoint_callback = NeMoModelCheckpoint(
    filepath=Path(log_dir / 'checkpoints' / '{val_loss:.2f}-{epoch}'),
    save_top_k=3,
    monitor='val_loss',
```
We need to find a way to expose these parameters to the user through the YAML config.
"val_loss" is not a given - RNNT can optionally disable it cause it is non-trivial cost to compute it over entire datasets (expensive Prediction step and Joint step cost). The default should be loss because that is understood to always be available (since backprop cannot occur without it).
There is also the case that we may not choose to supply a validation set at all (some datasets don't have a validation set, or user constructed dataset does not have a split for it).
Back in 0.9.0, Lightning used to assume that "val_loss" was the key people returned from their validation step.
While moving in a more general direction makes sense for Lightning, should we do so for NeMo? Or should we just enforce good defaults for model users? If they choose not to use the default, they have to manually adjust their scripts to make use of things like the ModelCheckpoint callback.
So currently ModelPT has no access to ExpManager (it's called before ModelPT), nor can it ensure the list of callbacks is in the correct order to override it at a later point. Even if it could parse the callback list, it can't access any of the path values, since exp manager is a stateless function.
For now, it's fine to keep it val_loss and provide an exp manager override to choose the monitor. I can override that in the RNNT configs.
Also, good defaults are very model-dependent. An experiment manager should offer this flexibility without having to completely override one of its core tasks.
Yes, I think the correct way is to leave "val_loss" as the default and provide an override via its YAML config. I'll think on it more once we merge this PR.
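One possible shape for such an override, sketched below. This is purely illustrative: `CheckpointParams` and `resolve_checkpoint_params` are hypothetical names, not exp_manager's actual schema; the idea is only that defaults live in one place and a model's YAML (e.g. an RNNT config setting `monitor: loss`) is merged on top.

```python
# Hypothetical sketch of a checkpoint-params override; all names here are
# assumptions, not NeMo's real exp_manager API.
from dataclasses import dataclass


@dataclass
class CheckpointParams:
    monitor: str = 'val_loss'   # default stays val_loss
    save_top_k: int = 3
    mode: str = 'min'


def resolve_checkpoint_params(overrides: dict) -> CheckpointParams:
    """Merge YAML-sourced overrides onto the defaults, ignoring unknown keys."""
    params = CheckpointParams()
    for key, value in overrides.items():
        if hasattr(params, key):
            setattr(params, key, value)
    return params
```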
This pull request introduces 3 alerts when merging e40b472 into fd98a89 - view on LGTM.com new alerts:
This pull request introduces 1 alert when merging 36b9f87 into fd98a89 - view on LGTM.com new alerts:
This pull request introduces 1 alert when merging d6d564d into fd98a89 - view on LGTM.com new alerts:
This pull request introduces 1 alert when merging ac7102a into fd98a89 - view on LGTM.com new alerts:
This pull request introduces 1 alert when merging 2a0f459 into fd98a89 - view on LGTM.com new alerts:
This pull request introduces 1 alert when merging 89c14ca into 87206f7 - view on LGTM.com new alerts:
This pull request introduces 1 alert when merging 5999ee9 into 87206f7 - view on LGTM.com new alerts:
This pull request introduces 2 alerts when merging a28a12d into 87206f7 - view on LGTM.com new alerts:
This pull request introduces 2 alerts when merging bf0f265 into 87206f7 - view on LGTM.com new alerts:
This pull request introduces 2 alerts when merging 4752293 into 87206f7 - view on LGTM.com new alerts:
This pull request introduces 2 alerts when merging 1bc44e9 into 87206f7 - view on LGTM.com new alerts:
This pull request introduces 2 alerts when merging a489ee6 into 87206f7 - view on LGTM.com new alerts:
```python
@@ -148,7 +148,10 @@ def validation_step(self, batch, batch_idx):
preds = torch.argmax(logits, axis=-1)[subtokens_mask]
labels = labels[subtokens_mask]
tp, fp, fn = self.classification_report(preds, labels)
self.classification_report(preds, labels)
```
@ericharper, can you review the changes to this file? I mostly copied from text_classification_model.py.
Specifically, can you check that the instantiation of self.classification_report makes sense in this file?
What's the difference between ['macro', 'micro', 'weighted']? Should everything just be micro?
'micro' is accuracy; all of these are different ways of aggregating tp, fp, fn per label and computing the final result. 'macro' should be the default.
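To make the distinction concrete, here is a small sketch of the three averaging modes for precision, using plain per-label count lists (the function name and signature are illustrative, not NeMo's ClassificationReport API):

```python
# Illustrative only: how 'micro', 'macro' and 'weighted' aggregate
# per-label tp/fp/fn counts into a single precision value.

def precision_by_mode(tp, fp, fn, mode='macro'):
    per_label = [t / (t + f) if (t + f) > 0 else 0.0 for t, f in zip(tp, fp)]
    if mode == 'micro':
        # pool all counts first, then divide; in single-label multi-class
        # classification this equals accuracy
        return sum(tp) / (sum(tp) + sum(fp))
    if mode == 'macro':
        # unweighted mean over labels: every label counts equally
        return sum(per_label) / len(per_label)
    if mode == 'weighted':
        # mean over labels weighted by support (tp + fn)
        support = [t + n for t, n in zip(tp, fn)]
        return sum(p * s for p, s in zip(per_label, support)) / sum(support)
    raise ValueError(f'unknown mode: {mode}')
```

The same aggregation pattern applies to recall and F1; only the per-label formula changes.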
ClassificationReport is used in 4 models:
- IntentSlotClassificationModel
- TextClassificationModel
- PunctuationCapitalizationModel
- TokenClassificationModel

Can I ask for a review on each model so we can approve the changes?
Updated all of the above models.
This pull request introduces 1 alert when merging 2e91818 into 87206f7 - view on LGTM.com new alerts:
Jenkinsfile:
```groovy
// TODO: Pytorch Lightning has some issues with restoring Metric classes, asked on the lightning slack if they can
// provide a simple solution.
// stage('L2: Parallel NLP Examples 2') {
```
'NER finetuning from pretrained Test' is currently blocked by Lightning-AI/pytorch-lightning#4195
Think you might be able to use Lightning.load_from_checkpoint(..., strict=False) to avoid the key mismatch (it uses load_state_dict(..., strict=False) under the hood).
This pull request introduces 1 alert when merging 71e8f6f into 910caf6 - view on LGTM.com new alerts:
This pull request introduces 1 alert when merging ae46780 into 910caf6 - view on LGTM.com new alerts:
Signed-off-by: ericharper <complex451@gmail.com>
This pull request introduces 1 alert when merging ae687b0 into 910caf6 - view on LGTM.com new alerts:
This pull request introduces 3 alerts when merging 6139420 into 910caf6 - view on LGTM.com new alerts:
This pull request introduces 3 alerts when merging bbcac81 into 910caf6 - view on LGTM.com new alerts:
This pull request introduces 3 alerts when merging 41e4be1 into 910caf6 - view on LGTM.com new alerts:
This pull request introduces 3 alerts when merging 247975d into 910caf6 - view on LGTM.com new alerts:
This pull request introduces 3 alerts when merging 8175ae7 into 910caf6 - view on LGTM.com new alerts:
The classification report changes LGTM.
```python
@@ -139,18 +153,18 @@ def get_precision_recall_f1(
    + '\n'
)
logging.info(report)
# logging.info(report)
```
Please remove if not needed
done
```python
# return {
#     'train_loss': loss,
#     'lr': self._optimizer.param_groups[0]['lr']
# }
```
Please remove if not needed
done
```python
tensorboard_logs = {'val_loss': avg_loss, 'exact_match': exact_match, 'f1': f1}
return {'val_loss': avg_loss, 'log': tensorboard_logs}
# tensorboard_logs = {'val_loss': avg_loss, 'exact_match': exact_match, 'f1': f1}
# return {'val_loss': avg_loss, 'log': tensorboard_logs}
```
same here
done
```python
@@ -80,7 +80,10 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
self.loss = self.setup_loss(class_balancing=self._cfg.dataset.class_balancing)
# setup to track metrics
self.classification_report = ClassificationReport(len(self._cfg.label_ids), label_ids=self._cfg.label_ids)
# TODO: What is the current mode?
```
'macro'
done, 'macro' is the default
Signed-off-by: ericharper <complex451@gmail.com>
This pull request introduces 2 alerts when merging 655fb48 into 910caf6 - view on LGTM.com new alerts:
Lightning Changes:
- `row_log_interval` -> `log_every_n_steps`: https://pytorch-lightning.readthedocs.io/en/latest/trainer.html#log-every-n-steps
- `distributed_backend` -> `accelerator`: https://pytorch-lightning.readthedocs.io/en/latest/trainer.html#accelerator
- `log_save_interval` -> `flush_logs_every_n_steps`
- `pl.callbacks.LearningRateLogger` -> `pl.callbacks.LearningRateMonitor`
- `TensorMetric` -> `Metric`: https://pytorch-lightning.readthedocs.io/en/latest/metrics.html [`WERBPE`, `WER`, `TopKClassificationAccuracy`, `ClassificationReport`, `Perplexity`]
- Use `self.log` for logging scalars instead of returning dictionaries from `training_step`, `training_epoch_end`, `validation...`, `test...`. The logger object should be used for logging anything other than scalars.

NeMo Changes:
- Added `strict` to most NeMo restoring functions to handle "Loading NLP and ASR models might result in `Missing key(s) in state_dict` error" #1297

Under the hood changes:
- `AttributeError: 'Trainer' object has no attribute 'configure_logger'`; now `Trainer.logger_connector.configure_logger`
- `configure_checkpoint_callback` -> `callback_connector.init_default_checkpoint_callback`