Skip to content

Commit

Permalink
Updated the documentation for multiclass nll loss
Browse files Browse the repository at this point in the history
  • Loading branch information
sumitpai committed Apr 15, 2019
1 parent b9fe5d1 commit 6211e5b
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 7 deletions.
21 changes: 17 additions & 4 deletions ampligraph/latent_features/loss_functions.py
Expand Up @@ -438,14 +438,27 @@ def _apply(self, scores_pos, scores_neg):

@register_loss("multiclass_nll", [], {'require_same_size_pos_neg':False})
class NLLMulticlass(Loss):
""" Multiclass NLL Loss
""" Multiclass NLL Loss.
Introduced in :cite: `chen2015` where both the subject and objects are corrupted (to use it in this way pass corrupt_sides = ['s', 'o'] to embedding_model_params) .
This loss was re-engineered in :cite: `kadlecBK17` where only the object was corrupted to get improved performance. (to use it in this way pass corrupt_sides = 'o' to embedding_model_params)
Introduced in :cite:`chen2015` where both the subject and objects are corrupted (to use it in this way pass corrupt_sides = ['s', 'o'] to embedding_model_params) .
..math::
This loss was re-engineered in :cite:`kadlecBK17` where only the object was corrupted to get improved performance (to use it in this way pass corrupt_sides = 'o' to embedding_model_params).
.. math::
\mathcal{L(X)} = -\sum_{x_{e_1,e_2,r_k} \in X} log\,p(e_2|e_1,r_k) -\sum_{x_{e_1,e_2,r_k} \in X} log\,p(e_1|r_k, e_2)
Examples
--------
>>> import numpy as np
>>> from ampligraph.latent_features import TransE
>>> model = TransE(batches_count=1, seed=555, epochs=20, k=10,
>>> embedding_model_params={'corrupt_side':['s', 'o']},
>>> loss='multiclass_nll', loss_params={})
"""
def __init__(self, eta, loss_params={}, verbose=False):
"""Initialize Loss
Expand Down
4 changes: 1 addition & 3 deletions ampligraph/latent_features/models.py
Expand Up @@ -1210,11 +1210,9 @@ def __init__(self,
- **'normalize_ent_emb'** (bool): flag to indicate whether to normalize entity embeddings after each batch update (default: False).
- **negative_corruption_entities** : entities to be used for generation of corruptions while training. It can take the following values : ``all`` (default: all entities), ``batch`` (entities present in each batch), list of entities or an int (which indicates how many entities that should be used for corruption generation).
- **corrupt_sides** : Specifies how to generate corruptions for training. Takes values `s`, `o`, `s+o` or any combination passed as a list
Example: ``embedding_model_params={'norm': 1, 'normalize_ent_emb': False}``
optimizer : string
The optimizer used to minimize the loss function. Choose between 'sgd',
'adagrad', 'adam', 'momentum'.
Expand Down
1 change: 1 addition & 0 deletions docs/ampligraph.latent_features.rst
Expand Up @@ -100,6 +100,7 @@ and they can be thus used :ref:`during model selection <eval>`.
NLLLoss
AbsoluteMarginLoss
SelfAdversarialLoss
NLLMulticlass

.. _ref-reg:

Expand Down
20 changes: 20 additions & 0 deletions docs/generated/ampligraph.latent_features.NLLMulticlass.rst
@@ -0,0 +1,20 @@
NLLMulticlass
==================================

.. currentmodule:: ampligraph.latent_features

.. autoclass:: NLLMulticlass





.. rubric:: Methods

.. autosummary::

~NLLMulticlass.__init__

.. automethod:: NLLMulticlass.__init__


0 comments on commit 6211e5b

Please sign in to comment.