
How to use ArcFaceLoss with trainer? #355

Closed
KennyTC opened this issue Jul 14, 2021 · 21 comments
Labels
question A general question about the library

Comments


KennyTC commented Jul 14, 2021

I am starting to use ArcFaceLoss, but I don't quite understand how to use it. Following the sample code, I think it should be like this:

# efnet outputs embeddings of size embeddings_num
efnet = models.create_model("efficient_net", embeddings_num).to(device)
optimizer = optim.Adam(efnet.parameters(), lr=0.00001, weight_decay=0.001)

# Set the arcface loss function
loss_func = losses.ArcFaceLoss(num_classes=9, embedding_size=embeddings_num).to(device)
loss_optimizer = optim.Adam(loss_func.parameters(), lr=0.00001) 

# Set the mining function
miner = miners.MultiSimilarityMiner(epsilon=0.1)
# Set the dataloader sampler
sampler = samplers.MPerClassSampler(train_data.targets, m=4, length_before_new_iter=len(train_data))

But then I am stuck on defining the models and loss_funcs dictionaries.

# Package the above stuff into dictionaries.
models = {"trunk": efnet} ???
optimizers = {"trunk_optimizer": optimizer, "arc_optimizer": loss_optimizer}
loss_funcs = ??
mining_funcs = {"tuple_miner": miner}

and keep the rest the same as your MetricLossOnly code sample?

from pytorch_metric_learning import losses, miners, samplers, trainers, testers
# batch_size=32
trainer = trainers.MetricLossOnly(models,
                                  optimizers,
                                  batch_size,
                                  loss_funcs,
                                  mining_funcs,
                                  train_data,
                                  sampler=sampler,
                                  dataloader_num_workers = 24,
                                  end_of_iteration_hook = hooks.end_of_iteration_hook,
                                  end_of_epoch_hook = end_of_epoch_hook)

KevinMusgrave commented Jul 14, 2021

I think this should work:

models = {"trunk": efnet}
optimizers = {"trunk_optimizer": optimizer, "metric_loss_optimizer": loss_optimizer}
loss_funcs = {"metric_loss": loss_func}
mining_funcs = {"tuple_miner": miner}


KennyTC commented Jul 15, 2021

It worked. Thank you for your speedy response.

KevinMusgrave added the question label on Jul 15, 2021

KennyTC commented Jul 16, 2021

Sorry to bother you again.

Following up on this topic, I have trained my model using ArcFace (epochs = 10, iter_per_epoch = 18, num_classes = 9). I noticed that:

  1. The loss is quite high compared to other loss functions, and sometimes it drops significantly. Is that usual when using ArcFace?
    [screenshot: loss curve]

  2. The accuracy (precision_at_1) is quite good (98%, 99%).
    [screenshot: accuracy plot]

And after training, I see the model can cluster pretty well:
[screenshot: embedding clusters]

When I use this model on the test set, the precision_at_1 is around 82%, lower than I expected. My questions are:

  1. Was my training weird, given that the loss is still very high and fluctuates a lot?
  2. Should I rely on loss (lower is better), accuracy (higher is better), or clustering (more separated is better) to pick a good model? After training for a while, I sometimes see:
  • the model clusters very well but accuracy is not so high, OR accuracy is very high (92%) but clustering is not very good;

  • accuracy is very high, but loss is still high (this case);

  • ...

KevinMusgrave commented:

  1. The loss is quite high compared to other loss functions, and sometimes it drops significantly. Is that usual when using ArcFace?

The high loss value is normal for ArcFace. I think the big spikes downward are caused by the tuple miner. If you get rid of the miner, the loss curve should be smoother.
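For reference, dropping the miner just means passing an empty mining_funcs dict to the trainer. A sketch reusing the variables from the setup above (I believe an empty dict is accepted, but check the trainers documentation):

# same trainer setup, but with no miner, so ArcFaceLoss
# is computed on every sample in the batch
mining_funcs = {}
trainer = trainers.MetricLossOnly(models,
                                  optimizers,
                                  batch_size,
                                  loss_funcs,
                                  mining_funcs,
                                  train_data,
                                  sampler=sampler)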

  2. The accuracy (precision_at_1) is quite good (98%, 99%).
    And after training, I see the model can cluster pretty well.
    When I use this model on the test set, the precision_at_1 is around 82%, lower than I expected.

Maybe your train/val sets are more similar to each other than your val/test sets?

  2. Should I rely on loss (lower is better), accuracy (higher is better), or clustering (more separated is better) to pick a good model? After training for a while, I sometimes see:
  • the model clusters very well but accuracy is not so high, OR accuracy is very high (92%) but clustering is not very good;
  • accuracy is very high, but loss is still high (this case);
  • ...

I wouldn't be concerned with the loss value. I would focus on maximizing accuracy, and not clustering (though they are usually highly correlated). I think very tight clustering does not generalize well.


KennyTC commented Jul 16, 2021

The high loss value is normal for ArcFace. I think the big spikes downward are caused by the tuple miner. If you get rid of the miner, the loss curve should be smoother.

I will try.

Maybe your train/val sets are more similar to each other than your val/test sets?

You are right.
I am trying to build face recognition models. I use data from 9 different IDs for training (and validation), so the train/val images are very similar (because they belong to the same IDs). For the test, I use 9 other IDs (16 images per ID as the gallery, the rest as queries). So the train/val sets are more similar to each other than the val/test sets.
Is it a good idea to choose IDs (in val data) ≠ IDs (in train data)? Maybe the model will generalize better?

I wouldn't be concerned with the loss value. I would focus on maximizing accuracy, and not clustering (though they are usually highly correlated). I think very tight clustering does not generalize well.

I see.

Just one more thing to confirm.
During inference, should I use only efnet to create the inference model:
inference_model = InferenceModel(efnet, match_finder=match_finder)
or
should I do something like: https://github.com/KevinMusgrave/pytorch-metric-learning/issues/117#issuecomment-647505992, i.e.

import torch
from pytorch_metric_learning.losses import ArcFaceLoss

loss_func = ArcFaceLoss(num_classes, embedding_size)
loss_func.load_state_dict(torch.load(<path_goes_here>))
combine_model = CombinedModel(efnet, loss_func)
inference_model = InferenceModel(combine_model, match_finder=match_finder)

Thank you!

KevinMusgrave commented:

I am trying to build face recognition models. I use data from 9 different IDs for training (and validation), so the train/val images are very similar (because they belong to the same IDs). For the test, I use 9 other IDs (16 images per ID as the gallery, the rest as queries). So the train/val sets are more similar to each other than the val/test sets.
Is it a good idea to choose IDs (in val data) ≠ IDs (in train data)? Maybe the model will generalize better?

I'm not sure what you mean. Can you give an example?

Just one more thing to confirm.
During inference, should I use only efnet to create the inference model:
inference_model = InferenceModel(efnet, match_finder=match_finder)

Yes, use only efnet.


KennyTC commented Jul 16, 2021

I'm not sure what you mean. Can you give an example?

For example, I have data for 9 dogs. Each folder contains the images of 1 dog.

(1) What I am doing now: splitting the images of these 9 dogs into train/val sets (that is why I said the train/val images are very similar).

(2) What I mean by "Is it a good idea to choose IDs (in val data) ≠ IDs (in train data)" is: use the images of these 9 dogs for the train set, and use images of different dogs for the val set.

Maybe this will make the model generalize better?


KevinMusgrave commented Jul 16, 2021

If your train and test sets are class-disjoint (i.e. they don't share any classes), then it would be a good idea to do the same for your train/val split. Then the val accuracy will be more correlated with your test accuracy. For example, you could use 6 classes for training, and 3 classes for validation.

One issue with this is that you've made training more difficult, because you're excluding those 3 classes from the training set. So you could try an ensemble approach, for example:

Train 3 different models with the following train/val splits:

  • train [0, 1, 2, 3, 4, 5], val [6, 7, 8]
  • train [3, 4, 5, 6, 7, 8], val [0, 1, 2]
  • train [6, 7, 8, 0, 1, 2], val [3, 4, 5]

This way, even though the individual models only train on 6 classes and validate on 3, the 3 of them together have trained and validated on all 9.

Load the best checkpoint of each training run. Then to evaluate on the test set, concatenate the embeddings of the 3 models.
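A minimal sketch of building those three class-disjoint splits (assuming train_data.targets holds the integer class labels 0-8, as in the sampler setup earlier; the resulting index arrays could then be used with e.g. torch.utils.data.Subset):

import numpy as np

labels = np.array(train_data.targets)
class_folds = [[6, 7, 8], [0, 1, 2], [3, 4, 5]]  # held-out val classes per run

splits = []
for val_classes in class_folds:
    val_mask = np.isin(labels, val_classes)
    train_idx = np.where(~val_mask)[0]  # samples from the other 6 classes
    val_idx = np.where(val_mask)[0]     # samples from the 3 held-out classes
    splits.append((train_idx, val_idx))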


KennyTC commented Jul 17, 2021

Very helpful advice. I had never thought about an ensemble approach for this problem.
If I have trained the 3 models above, then will the inference be something like this?

indexer = FaissIndexer()
indexer.train_index(gallery_data)
# note: each model embeds into its own space, so in practice each
# InferenceModel would need its own index rather than a shared one
inference_model1 = InferenceModel(model1, match_finder=match_finder, indexer=indexer)
inference_model2 = InferenceModel(model2, match_finder=match_finder, indexer=indexer)
inference_model3 = InferenceModel(model3, match_finder=match_finder, indexer=indexer)

for img, label in query_loader:
    indices1, distances1 = inference_model1.get_nearest_neighbors(img, k=1)
    indices2, distances2 = inference_model2.get_nearest_neighbors(img, k=1)
    indices3, distances3 = inference_model3.get_nearest_neighbors(img, k=1)
    nearest_index = # index of smallest distance among distances1, distances2, distances3
    predicted_label = gallery_data[nearest_index]
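One possible way to fill in the combination step (a sketch, assuming get_nearest_neighbors returns (indices, distances) as [batch_size, k] tensors, as in the snippet above):

import torch

# stack the three models' results; with k=1 each tensor is [batch_size, 1]
dists = torch.stack([d.squeeze(1) for d in (distances1, distances2, distances3)])  # [3, batch]
idxs = torch.stack([i.squeeze(1) for i in (indices1, indices2, indices3)])         # [3, batch]
best = dists.argmin(dim=0)                               # which model found the closest match
nearest_index = idxs[best, torch.arange(best.shape[0])]  # per-query gallery index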

KevinMusgrave commented:

That is one approach. Another way is to concatenate the normalized embeddings:

import torch
import torch.nn.functional as F

class EnsembleModel(torch.nn.Module):
    def __init__(self, models):
        super().__init__()  # required before registering submodules
        self.models = torch.nn.ModuleList(models)

    def forward(self, x):
        # L2-normalize each model's embedding, then concatenate along the feature dim
        return torch.cat([F.normalize(m(x)) for m in self.models], dim=1)

ensemble = EnsembleModel([model1, model2, model3])
inference_model = InferenceModel(ensemble, match_finder=match_finder, indexer=indexer)
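(Note: if each trunk outputs 2048-dim embeddings, the concatenated ensemble embedding is 3 × 2048 = 6144-dim, so the gallery index must be built from the ensemble's embeddings rather than reused from a single model.)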


KennyTC commented Jul 17, 2021

You are really a legend. Thank you so much.

KevinMusgrave commented:

You're welcome


KennyTC commented Jul 28, 2021

One follow-up question:
In case I have the following split, where the train and val classes are disjoint:

  • train [0, 1, 2, 3, 4, 5], val [6, 7, 8]

then how is the accuracy on val calculated?

Normally, the accuracy calculation is something like

predicted = model(X)
accuracy_score(label, predicted)

but here predicted will only take values [0,1,2,3,4,5] because it is the output of the model, while the labels are [6,7,8]?

KevinMusgrave commented:

You can make train+val the reference set, so the training set will be like "noise" that you don't want to retrieve.

If you're using the trainers module, the default is to just use the 'val' set. To use train+val:

splits_to_eval = [('val', ['train', 'val'])]
end_of_epoch_hook = hooks.end_of_epoch_hook(tester, 
                                            dataset_dict, 
                                            model_folder, 
                                            test_interval,
                                            patience,
                                            splits_to_eval = splits_to_eval)

KevinMusgrave commented:

but here predicted will only take values [0,1,2,3,4,5] because it is the output of the model, while the labels are [6,7,8]?

The accuracy is computed using k-nearest-neighbors in the embedding space. The class logits are not used.
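For intuition, a rough sketch of precision@1 in the embedding space (plain PyTorch, not the library's actual AccuracyCalculator implementation; it assumes the query and reference sets are disjoint, so no self-match handling is needed):

import torch
import torch.nn.functional as F

def precision_at_1(query_emb, query_labels, ref_emb, ref_labels):
    # cosine similarity between every query and every reference embedding
    sims = F.normalize(query_emb) @ F.normalize(ref_emb).T
    nearest = sims.argmax(dim=1)  # each query's single nearest reference
    # a query counts as correct if its nearest neighbor shares its label
    return (ref_labels[nearest] == query_labels).float().mean().item()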


KennyTC commented Jul 29, 2021

You can make train+val the reference set, so the training set will be like "noise" that you don't want to retrieve.

Many thanks. That is actually the problem I am facing now: the model is very good on train and val, but does not generalize well to the test set. Maybe train and val are easier than the test set.

The accuracy is computed using k-nearest-neighbors in the embedding space. The class logits are not used.

I see. Can you explain more about how the accuracy is calculated (a link or code is fine)? I just want to make sure I understand things clearly.

Besides, does the following training log mean this:

  • it is running knn with k = 19, i.e.
  • for each point in my val set, there will be the nearest 19 points
  • each point has a 2048-dim embedding --> in total the embedding space will have len(val)*19*2048 vectors
  • then these vectors will be clustered into 27 clusters (in this case, 27 classes)

[screenshot: training log]

KevinMusgrave commented:

I see. Can you explain more about how the accuracy is calculated (a link or code is fine)? I just want to make sure I understand things clearly.

See this link

  • it is running knn with k = 19, i.e.

Yes

  • for each point in my val set, there will be the nearest 19 points

Yes

  • each point has a 2048-dim embedding --> in total the embedding space will have len(val)*19*2048 vectors

No, it will find the nearest 19 vectors for each point in val. The total number of vectors retrieved is therefore len(val)*19. Each vector has dimensionality 2048.

  • then these vectors will be clustered into 27 clusters (in this case, 27 classes)

Yes, but the clustering result is used only for the AMI and NMI accuracy metrics, not for the other metrics like precision.


KennyTC commented Jul 30, 2021

No, it will find the nearest 19 vectors for each point in val. The total number of vectors retrieved is therefore len(val)*19. Each vector has dimensionality 2048.

Yes, I see. I explained it wrongly.

It seems that when I use both train+val as the reference set, there is no longer an accuracy for the train set, i.e., the following code raises an error on accuracies = hooks.get_accuracy_history(tester, "train", return_all_metrics=True). Why is that? I still want to plot accuracy on both train and val to judge whether my model is overfitting.

for c, ds in zip(['r', 'b'], ["train", "val"]):
    accuracies = hooks.get_accuracy_history(tester, ds, return_all_metrics=True)
    plt.plot(accuracies['epoch'], accuracies['precision_at_1_level0'], '{}x-'.format(c), label=ds)
plt.legend()
plt.title("Precision@1")  # the plotted metric is precision_at_1, not AMI

KevinMusgrave commented:

Set splits_to_eval:

reference = ['train', 'val']
splits_to_eval = [('train', reference), ('val', reference)]

See the documentation: https://kevinmusgrave.github.io/pytorch-metric-learning/testers/#testing-splits


KennyTC commented Feb 25, 2022

Just to confirm:

Yes, but the clustering result is used only for the AMI and NMI accuracy metrics, not for the other metrics like precision.

So, if I use precision_at_1 as my main metric (i.e. hooks = logging_presets.get_hook_container(record_keeper, primary_metric="precision_at_1")), then the log lines above (i.e. "running knn with k = 19 ... these vectors will be clustered into 27 clusters") are not relevant, right?


KevinMusgrave commented Feb 25, 2022

Correct. And if you don't want to compute those clustering metrics at all, then you can set the exclude parameter of AccuracyCalculator:

from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator

ac = AccuracyCalculator(exclude=("AMI", "NMI"))
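The calculator can then be passed to a tester, e.g. (a sketch, assuming the GlobalEmbeddingSpaceTester used elsewhere in the library's examples):

tester = testers.GlobalEmbeddingSpaceTester(accuracy_calculator=ac)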
