
Refactoring
Signed-off-by: Tuan Lai <tuanl@nvidia.com>
Tuan Lai committed Jul 3, 2021
1 parent aafef49 commit 9325b04
Showing 4 changed files with 48 additions and 42 deletions.
==== File 1 of 4 (path not shown in this view) ====

@@ -33,7 +33,7 @@ tagger_model:
       name: WarmupAnnealing

       # pytorch lightning args
-      monitor: val_token_precision
+      monitor: val_loss
       reduce_on_plateau: false

       # scheduler config override
@@ -48,7 +48,7 @@ tagger_exp_manager:
   create_checkpoint_callback: True
   checkpoint_callback_params:
     save_top_k: 3
-    monitor: "val_token_precision"
+    monitor: "val_loss"
     mode: "max"
     save_best_model: true
     always_save_nemo: true
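
A NeMo exp_manager block like the one above hands checkpoint_callback_params to a PyTorch Lightning ModelCheckpoint-style callback. A rough plain-Lightning equivalent of keeping the top-3 checkpoints by val_loss, for illustration only (note that a loss metric is normally tracked with mode="min"):

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelCheckpoint

    # Keep the 3 checkpoints with the best logged validation loss.
    checkpoint_callback = ModelCheckpoint(
        monitor="val_loss",
        save_top_k=3,
        mode="min",  # a loss is minimized; "max" fits scores such as precision
    )
    trainer = Trainer(callbacks=[checkpoint_callback])
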
==== File 2 of 4 (path not shown in this view) ====

@@ -99,6 +99,11 @@ def main(cfg: DictConfig) -> None:
         tagger_model.save_to(cfg.tagger_model.nemo_path)
         logging.info('Training finished!')

+        logging.info('Testing the trained tagger...')
+        tagger_model.setup_test_data(test_data_config=cfg.data.test_ds)
+        tagger_trainer.test(tagger_model)
+        logging.info('Testing finished!')
+
     # Train the decoder
     if cfg.decoder_model.do_training:
         logging.info("================================================================================================")
@@ -111,5 +116,10 @@ def main(cfg: DictConfig) -> None:
         decoder_model.save_to(cfg.decoder_model.nemo_path)
         logging.info('Training finished!')

+        logging.info('Testing the trained decoder...')
+        decoder_model.setup_test_data(test_data_config=cfg.data.test_ds)
+        decoder_trainer.test(decoder_model)
+        logging.info('Testing finished!')
+
 if __name__ == '__main__':
     main()
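
The lines added above run the trained tagger and decoder once on the test split right after training. In plain PyTorch Lightning terms this is the usual fit-then-test flow; a minimal sketch under that assumption (generic names, not this repository's classes):

    import pytorch_lightning as pl

    def fit_and_test(model: pl.LightningModule, trainer: pl.Trainer) -> None:
        # Training loop: training_step / validation_step / validation_epoch_end.
        trainer.fit(model)
        # One pass over the model's test dataloader: test_step / test_epoch_end.
        trainer.test(model)

In the commit itself, setup_test_data(...) points the NeMo model at cfg.data.test_ds before trainer.test is called.
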
==== File 3 of 4 (path not shown in this view) ====

@@ -53,10 +53,10 @@ def training_step(self, batch, batch_idx):
"""
# Apply Transformer
outputs = self.model(
input_ids=batch['input_ids'].to(self.device),
decoder_input_ids=batch['decoder_input_ids'].to(self.device),
attention_mask=batch['attention_mask'].to(self.device),
labels=batch['labels'].to(self.device),
input_ids=batch['input_ids'],
decoder_input_ids=batch['decoder_input_ids'],
attention_mask=batch['attention_mask'],
labels=batch['labels']
)
train_loss = outputs.loss

@@ -77,10 +77,10 @@ def validation_step(self, batch, batch_idx):

         # Apply Transformer
         outputs = self.model(
-            input_ids=batch['input_ids'].to(self.device),
-            decoder_input_ids=batch['decoder_input_ids'].to(self.device),
-            attention_mask=batch['attention_mask'].to(self.device),
-            labels=batch['labels'].to(self.device),
+            input_ids=batch['input_ids'],
+            decoder_input_ids=batch['decoder_input_ids'],
+            attention_mask=batch['attention_mask'],
+            labels=batch['labels']
         )
         val_loss = outputs.loss

@@ -94,11 +94,7 @@ def validation_epoch_end(self, outputs):
         :param outputs: list of individual outputs of each validation step.
         """
         avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
-        self.log('val_loss', avg_loss)
-
-        return {
-            'val_loss': avg_loss,
-        }
+        self.log('val_loss', avg_loss, sync_dist=True)

     def test_step(self, batch, batch_idx):
         """
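
Two Lightning behaviours the decoder changes above rely on: the Trainer moves each batch onto the module's device before training_step/validation_step runs, so the explicit .to(self.device) calls were redundant, and sync_dist=True tells self.log to reduce the value across processes under DDP. A minimal sketch of both points, assuming a generic Hugging Face-style seq2seq model rather than this repository's class:

    import torch
    import pytorch_lightning as pl

    class Seq2SeqExample(pl.LightningModule):
        def __init__(self, model):
            super().__init__()
            self.model = model  # any encoder-decoder whose output exposes .loss

        def validation_step(self, batch, batch_idx):
            # Batch tensors already live on self.device here; no manual .to() needed.
            outputs = self.model(
                input_ids=batch['input_ids'],
                attention_mask=batch['attention_mask'],
                labels=batch['labels'],
            )
            return {'val_loss': outputs.loss}

        def validation_epoch_end(self, outputs):
            avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
            # sync_dist=True averages the logged value across DDP workers.
            self.log('val_loss', avg_loss, sync_dist=True)
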
==== File 4 of 4 (path not shown in this view) ====

@@ -66,13 +66,11 @@ def training_step(self, batch, batch_idx):
         num_labels = self.num_labels

         # Apply Transformer
-        input_ids = batch['input_ids'].to(self.device)
-        input_masks = batch['attention_mask'].to(self.device)
-        tag_logits = self.model(input_ids, input_masks).logits
+        tag_logits = self.model(batch['input_ids'], batch['attention_mask']).logits

         # Loss
-        tag_labels = batch['labels'].view(-1).to(self.device)
-        train_loss = self.loss_fct(tag_logits.view(-1, num_labels), tag_labels)
+        train_loss = self.loss_fct(tag_logits.view(-1, num_labels),
+                                   batch['labels'].view(-1))

         lr = self._optimizer.param_groups[0]['lr']
         self.log('train_loss', train_loss)
@@ -89,50 +87,52 @@ def validation_step(self, batch, batch_idx):
         passed in as `batch`.
         """
         # Apply Transformer
-        input_ids = batch['input_ids'].to(self.device)
-        input_masks = batch['attention_mask'].to(self.device)
-        tag_logits = self.model(input_ids, input_masks).logits
+        tag_logits = self.model(batch['input_ids'], batch['attention_mask']).logits
         tag_preds = torch.argmax(tag_logits, dim=2)

         # Loss
-        tag_labels = batch['labels'].to(self.device)
-
-        # Update classification_report
-        predictions, labels = tag_preds.tolist(), tag_labels.tolist()
-        for prediction, label in zip(predictions, labels):
-            cur_preds = [p for (p, l) in zip(prediction, label) if l != constants.LABEL_PAD_TOKEN_ID]
-            cur_labels = [l for (p, l) in zip(prediction, label) if l != constants.LABEL_PAD_TOKEN_ID]
-            self.classification_report(torch.tensor(cur_preds).to(self.device),
-                                        torch.tensor(cur_labels).to(self.device))
+        val_loss = self.loss_fct(tag_logits.view(-1, self.num_labels),
+                                 batch['labels'].view(-1))

+        return {'val_loss': val_loss}

     def validation_epoch_end(self, outputs):
         """
         Called at the end of validation to aggregate outputs.
         :param outputs: list of individual outputs of each validation step.
         """
-        # calculate metrics and classification report
-        precision, _, _, report = self.classification_report.compute()
-
-        logging.info(report)
-
-        self.log('val_token_precision', precision)
-
-        self.classification_report.reset()
+        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
+        self.log('val_loss', avg_loss, sync_dist=True)

     def test_step(self, batch, batch_idx):
         """
         Lightning calls this inside the test loop with the data from the test dataloader
         passed in as `batch`.
         """
-        return self.validation_step(batch, batch_idx)
+        # Apply Transformer
+        tag_logits = self.model(batch['input_ids'], batch['attention_mask']).logits
+        tag_preds = torch.argmax(tag_logits, dim=2)
+
+        # Update classification_report
+        predictions, labels = tag_preds.tolist(), batch['labels'].tolist()
+        for prediction, label in zip(predictions, labels):
+            cur_preds = [p for (p, l) in zip(prediction, label) if l != constants.LABEL_PAD_TOKEN_ID]
+            cur_labels = [l for (p, l) in zip(prediction, label) if l != constants.LABEL_PAD_TOKEN_ID]
+            self.classification_report(torch.tensor(cur_preds).to(self.device),
+                                        torch.tensor(cur_labels).to(self.device))

     def test_epoch_end(self, outputs):
         """
         Called at the end of test to aggregate outputs.
         :param outputs: list of individual outputs of each test step.
         """
-        return self.validation_epoch_end(outputs)
+        # calculate metrics and classification report
+        precision, _, _, report = self.classification_report.compute()
+
+        logging.info(report)
+        self.log('test_token_precision', precision)
+
+        self.classification_report.reset()

     # Functions for inference
     @torch.no_grad()
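
The test_step added above scores only real tokens: positions whose label equals the padding id are dropped from both predictions and labels before they are fed to the classification report. A self-contained sketch of that filtering (the pad id below is a stand-in, not necessarily the value of constants.LABEL_PAD_TOKEN_ID):

    import torch

    LABEL_PAD_TOKEN_ID = -100  # stand-in pad label id

    def drop_padded_positions(pred_row, label_row, pad_id=LABEL_PAD_TOKEN_ID):
        """Keep predictions/labels only where a real (non-padding) label is present."""
        kept = [(p, l) for p, l in zip(pred_row, label_row) if l != pad_id]
        preds = torch.tensor([p for p, _ in kept])
        labels = torch.tensor([l for _, l in kept])
        return preds, labels

    # Positions 2 and 4 carry the pad label and are ignored.
    preds, labels = drop_padded_positions([3, 1, 2, 0], [3, -100, 2, -100])
    assert preds.tolist() == [3, 2] and labels.tolist() == [3, 2]
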
