Plotting training and validation loss #122

Closed
simonasantamaria opened this issue Jun 13, 2020 · 23 comments
Labels: Frequently Asked Questions, question (A general question about the library)

Comments

@simonasantamaria

I would like to plot training and validation loss over the training iterations. I'm using hooks.get_loss_history() together with record-keeper to visualize the loss. It works, but I'm not able to plot the training and validation loss in the same plot, and I'm not sure which loss I'm plotting with hooks.get_loss_history() in the first place. I'd be grateful for any advice, thanks!

@KevinMusgrave (Owner) commented Jun 13, 2020

Losses are only computed on the training set, not the validation set. The reason is that for pair/triplet based losses, it doesn't really make sense to compute a single value for the entire set, unless you're willing to form every possible pair/triplet. Plus, some loss values don't correlate with accuracy at all.

So everything returned by get_loss_history corresponds to the training set. The function will return a dictionary where one of the keys is total_loss, and the rest of the keys are the components of that total loss. If you're using the MetricLossOnly trainer, then the only component is metric_loss, so total_loss and metric_loss will be the same in that case.
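
For example, a minimal sketch of plotting the training loss per iteration (assuming matplotlib is available; the key names come from the dictionary described above):

import matplotlib.pyplot as plt

loss_history = hooks.get_loss_history()  # dict: loss name -> list of per-iteration values
plt.plot(loss_history["total_loss"], label="total_loss (train)")
plt.xlabel("iteration")
plt.ylabel("loss")
plt.legend()
plt.show()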

To compare train/val performance, you can get the tester to compute accuracy for both the training set and validation set, by changing the dataset_dict.

dataset_dict = {"train": train_dataset, "val": val_dataset}
model_folder = "example_saved_models"
tester = testers.GlobalEmbeddingSpaceTester(end_of_testing_hook=hooks.end_of_testing_hook)
end_of_epoch_hook = hooks.end_of_epoch_hook(tester, dataset_dict, model_folder)

And then use get_accuracy_history to plot the numbers.

train_acc = hooks.get_accuracy_history(tester, "train") 
val_acc = hooks.get_accuracy_history(tester, "val")

# one way you could extract the information to plot
train_epochs, val_epochs = train_acc.pop("epoch"), val_acc.pop("epoch")
assert train_epochs == val_epochs

for acc_name in train_acc.keys():
    curr_train_acc = train_acc[acc_name]
    curr_val_acc = val_acc[acc_name]
    # add to plot here #

By default, only the primary metric will be returned by get_accuracy_history. But you can also get multiple metrics if you want. See the bottom of this page for more info.
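
For example, a hedged sketch of requesting more than the primary metric (the exact keyword arguments are documented under logging_presets; return_all_metrics and metrics here are assumptions about that signature):

# all metrics the tester computed
train_acc = hooks.get_accuracy_history(tester, "train", return_all_metrics=True)
# or just a chosen subset
val_acc = hooks.get_accuracy_history(tester, "val", metrics=["mean_average_precision_at_r", "AMI"])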

If you want to make a plot at the end of every epoch, you should be able to do something like this:

dataset_dict = {"train": train_dataset, "val": val_dataset}
model_folder = "example_saved_models"
tester = testers.GlobalEmbeddingSpaceTester(end_of_testing_hook=hooks.end_of_testing_hook)
helper_end_of_epoch_hook = hooks.end_of_epoch_hook(tester, dataset_dict, model_folder)

def end_of_epoch_hook(trainer):
    should_continue_training = helper_end_of_epoch_hook(trainer)
    train_acc = hooks.get_accuracy_history(tester, "train") 
    val_acc = hooks.get_accuracy_history(tester, "val")
    # do your plotting here #
    return should_continue_training
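
For completeness, a rough sketch of wiring this custom hook into a trainer (the models, optimizers, losses, miners, and sampler below are placeholders for whatever you already use, and the exact constructor arguments may vary by version):

trainer = trainers.MetricLossOnly(models=models,
                                  optimizers=optimizers,
                                  batch_size=batch_size,
                                  loss_funcs=loss_funcs,
                                  mining_funcs=mining_funcs,
                                  dataset=train_dataset,
                                  sampler=sampler,
                                  end_of_iteration_hook=hooks.end_of_iteration_hook,
                                  end_of_epoch_hook=end_of_epoch_hook)
trainer.train(num_epochs=num_epochs)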

KevinMusgrave added the Frequently Asked Questions and question labels on Jun 13, 2020
@KevinMusgrave (Owner)

Closing for now. Please reopen if you have more questions.

@KennyTC commented Jul 8, 2021

I am curious about the value of the x-axis. I am sure it is not the number of epochs, but it also doesn't seem to be the number of iterations.

I set batch_size = 32 in the trainer, and my training data has 458 images --> there should be 15 batches per epoch.
My training stopped after 25 epochs (see the accuracy plot in the 1st image). So if the x-axis is the number of iterations, it should be 15*25 = 375. But according to the plot, the length of loss_history["metric_loss"] is just 338 (the 2nd image).

[1st screenshot: accuracy plot]

[2nd screenshot: plot of loss_history["metric_loss"]]

  1. How can I get the loss after each epoch instead of each iteration?
  2. I see that the best model you save is the model at epoch 20. From the 1st image, I know that it is not the model with the best accuracy. Is it the epoch with the smallest loss?
    [screenshot]

@KevinMusgrave (Owner) commented Jul 8, 2021

I set batch_size = 32 in the trainer, and my training data has 458 images --> there should be 15 batches per epoch.

Actually it will be floor(458/32), which is 14 batches.

My training stopped after 25 epochs (see the accuracy plot in the 1st image). So if the x-axis is the number of iterations, it should be 15*25 = 375. But according to the plot, the length of loss_history["metric_loss"] is just 338 (the 2nd image).

The x-axis should be the number of iterations, so it should have length 14*25 = 350. It looks like the plot stops short of 350, but have you manually checked the length to be sure?

print(len(loss_history["metric_loss"]))

  1. How can I get the loss after each epoch instead of each iteration?

That functionality doesn't exist currently. As a workaround you could try taking the average of each 14-element chunk in loss_history.

  2. I see that the best model you save is the model at epoch 20. From the 1st image, I know that it is not the model with the best accuracy. Is it the epoch with the smallest loss?

The "best" model is based on the primary_metric used by HookContainer. The default primary metric is "mean_average_precision_at_r". Can you try plotting that instead and see if the peak validation accuracy is at the 20th epoch?

Or you can change the primary metric to the one that you're plotting, which is AMI. Then the best model will be chosen based on AMI.

hook = logging_presets.get_hook_container(record_keeper, primary_metric="AMI")

@KennyTC commented Jul 9, 2021

Hi, thank you for your reply.

  1. First, len(loss_history["metric_loss"]) and the calculation don't seem to match. E.g.,
    I tried batch_size=16 (the batch_size of the trainer), my len(train_data)=458, and ran until the 50th epoch,
    so the number of iterations should be floor(458/16)*50 = 1400, but I checked and len(loss_history["metric_loss"]) = 1350. There is a difference of 50 iterations.
    I can take the average of each (458/16)-element chunk in loss_history, but since the total iterations don't match, I wonder if calculating it that way may be wrong.

  2. Yes, I have checked and it seems that the model is saved based on the default metric.

@KevinMusgrave (Owner) commented Jul 9, 2021

Your first post actually shows training went for 26 epochs (not 25). If it ran for 338 iterations, then there were 13 batches per epoch, which implies a training size between 416 and 447.

Now you've trained for 50 epochs. If it ran for 1350 iterations, then there were 27 batches per epoch, which implies a training size between 432 and 447.

Can you double check the training set size, and also show the code where you initialize/run the hooks, tester, and trainer?

Also a screenshot of the training logs would be helpful because it should have progress bars that indicate the number of iterations per epoch.

@KennyTC commented Jul 9, 2021

Can you double check the training set size, and also show the code where you initialize/run the hooks, tester, and trainer?

Yes, you're right. Each epoch has 27 batches. But I'm sure my training data has 458 images, as shown in the image.

[screenshot showing the training set size of 458]

@KevinMusgrave (Owner)

Can you show the part where you initialize the sampler?

If you're using MPerClassSampler, then what's probably happening is that it is reducing the epoch size to ensure the list is a multiple of m.

@KennyTC commented Jul 9, 2021

Yes, I use m=4 in the MPerClassSampler. But as I asked you in a different thread, each batch still has 16 images (the batch_size value in the trainer), right?

[screenshot]

@KevinMusgrave (Owner)

Yes, it's just that the length of the iterable returned by MPerClassSampler gets truncated such that it is a multiple of m*num_unique_labels. See here:

self.length_of_single_pass = self.m_per_class * len(self.labels)
self.list_size = length_before_new_iter
if self.batch_size is None:
    if self.length_of_single_pass < self.list_size:
        self.list_size -= (self.list_size) % (self.length_of_single_pass)

So that's why the number of iterations per epoch is slightly less than expected.
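
As an illustration only, with an assumed number of unique labels (this branch applies when the sampler is constructed without its own batch_size argument):

m = 4
num_unique_labels = 108                           # assumed value, for illustration only
length_of_single_pass = m * num_unique_labels     # 432
list_size = 458                                   # len(train_data)
list_size -= list_size % length_of_single_pass    # 458 -> 432
iters_per_epoch = list_size // 16                 # 432 // 16 = 27 batches per epoch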

@KennyTC commented Jul 9, 2021

Thanks for making it clear. Very helpful.

@KennyTC commented Jul 13, 2021

Sorry to bother you again.

Yes, you're right. Each epoch has 27 batches.

How can I calculate the number 27 automatically, given my len(train_data)=458, m=4, and a trainer batch_size of 16?
I would like to calculate the average loss after each epoch, so I need this number to be calculated automatically.

@KevinMusgrave (Owner)

iters_per_epoch = len(sampler) // batch_size
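
As a rough sketch of using that to get a per-epoch average loss (sampler, batch_size, and hooks here refer to your own objects):

iters_per_epoch = len(sampler) // batch_size          # e.g. 432 // 16 = 27 in this case
losses = hooks.get_loss_history()["metric_loss"]
epoch_avg = [sum(losses[i:i + iters_per_epoch]) / iters_per_epoch
             for i in range(0, len(losses) - iters_per_epoch + 1, iters_per_epoch)]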

@talbenha

Hi @KevinMusgrave,
I tried to plot the training and validation loss together on the same plot in TensorBoard using your explanation; however, as far as I can see, I can't access the validation loss. Also, from your explanation, I can only access the 'mean_average_precision_at_r_level0' accuracy computed on the training or validation data.

Can you please guide me or provide code on how to plot the training and validation loss together?

Thanks,
Tal

@talbenha

Thanks for your response @KevinMusgrave.
In the attached link I see only one loss graph (I guess it is the training loss) and two accuracies ('mean_average_precision_at_r_level0') for the training and validation data, as I mentioned in the previous comment.

How can I access both losses (train and validation) and track them during training in TensorBoard?

@KevinMusgrave (Owner)

Validation loss isn't computed because there is no clear definition for it with tuple-based losses. For example, if you're using triplet loss, you wouldn't want to use all triplets (there will be too many). You could try a random sampling of triplets and compute the loss based on that, but that functionality isn't built in to this library.
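
If you wanted to roll that yourself, here is a rough sketch (not part of the library's trainers; it assumes a triplet-style loss that accepts an indices_tuple, and uses get_random_triplet_indices from the library's loss_and_miner_utils, whose exact signature may vary by version; model, val_data, val_labels, and loss_fn are your own objects):

import torch
from pytorch_metric_learning.utils import loss_and_miner_utils as lmu

with torch.no_grad():
    embeddings = model(val_data)  # embeddings for (a subset of) the validation set
    # sample a fixed number of random triplets per anchor instead of forming all of them
    indices_tuple = lmu.get_random_triplet_indices(val_labels, t_per_anchor=10)
    val_loss = loss_fn(embeddings, val_labels, indices_tuple)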

@talbenha commented Aug 22, 2021

OK, thanks @KevinMusgrave. Can you please guide me on how to do this?
How can I try a random sampling of triplets and compute the loss? What modification is needed: adding a loss and a miner in the trainer for the validation dataset? I think this is trivial functionality for training a NN.

@KevinMusgrave (Owner)

You can loop through the validation set:

loss = 0
with torch.no_grad():  # no gradients needed when just evaluating
    for data, labels in val_loader:
        embeddings = model(data)
        loss += loss_fn(embeddings, labels).item()
loss /= len(val_loader)

I think this is trivial functionality for training a NN.

It's not included in this library because the loss value is often meaningless, e.g. triplet loss can stay flat from the beginning to the end of training while validation accuracy still goes up.

@KevinMusgrave (Owner)

If you're using the trainers and hooks you can do this:

import tqdm
import torch
from pytorch_metric_learning import testers
from pytorch_metric_learning.utils import common_functions  # provides EmbeddingDataset
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator
def end_of_testing_hook(tester):
    for split, (embeddings, labels) in tester.embeddings_and_labels.items():
        dataset = common_functions.EmbeddingDataset(embeddings.cpu().numpy(), labels.squeeze(1).cpu().numpy())
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, num_workers=1)
        total_loss = 0
        with torch.no_grad():
            print(f"getting loss for {split} set")
            for E, L in tqdm.tqdm(dataloader):
                total_loss += loss(E, L)  # `loss` here is your loss function instance
        total_loss /= len(dataloader)
        tester.all_accuracies[split]["loss"] = total_loss
    hooks.end_of_testing_hook(tester)


# Create the tester
tester = testers.GlobalEmbeddingSpaceTester(end_of_testing_hook = end_of_testing_hook, 
                                            dataloader_num_workers = 2,
                                            accuracy_calculator=AccuracyCalculator(k="max_bin_count"))

@talbenha

Thanks @KevinMusgrave, it works!

@KennyTC commented Feb 25, 2022

Hi @KevinMusgrave

I understand the point that with ContrastiveLoss or TripletLoss, the input is a pair or a triplet of images, so maybe there is no clear definition of how to calculate it.

But what about the case of ArcFace/CosFace/AdaCos? I found a full implementation of a model based on ArcFace/CosFace/AdaCos (https://www.kaggle.com/tanulsingh077/pytorch-metric-learning-pipeline-only-images).
In this implementation, they select the model based on the training loss and validation loss.
Is the way you use ArcFace the same as in the link above? If so, should the loss on the validation set be included in the library?
Thanks a lot.

@KevinMusgrave (Owner)

You raise a good point regarding classification losses. And since it seems common to compute the average loss during validation, it's probably worth adding to the library. I've created a separate issue to keep track of this feature.
