
Add embedding sim
CyberZHG committed Feb 21, 2019
1 parent 52b1ce2 commit dc21040
Showing 4 changed files with 76 additions and 1 deletion.
13 changes: 12 additions & 1 deletion README.md
@@ -12,5 +12,16 @@ pip install torch-embed-sim
## Usage

```python
import torch
import torch.nn as nn

from torch_embed_sim import EmbeddingSim


class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        self.embed = torch.nn.Embedding(num_embeddings=10, embedding_dim=20)
        self.embed_sim = EmbeddingSim(num_embeddings=10)

    def forward(self, x):
        # Score each embedded token against the full embedding matrix.
        return self.embed_sim(self.embed(x), self.embed.weight)
```
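For reference, a minimal sketch of calling the example network above (the shapes follow from the constructor arguments in the snippet; the `assert`s are illustrative, not part of the package):

```python
net = Net()
x = torch.randint(0, 10, (2, 5))  # a batch of token indices
probs = net(x)                    # shape: (2, 5, 10)
# Every position carries a softmax distribution over the 10 embeddings.
assert probs.shape == (2, 5, 10)
assert torch.allclose(probs.sum(dim=-1), torch.ones(2, 5))
```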
31 changes: 31 additions & 0 deletions tests/test_embedding_sim.py
@@ -0,0 +1,31 @@
from unittest import TestCase
import torch
import torch.nn as nn
from torch_embed_sim import EmbeddingSim


class TestEmbeddingSim(TestCase):

    def test_sample(self):
        class Net(nn.Module):

            def __init__(self):
                super(Net, self).__init__()
                self.embed = torch.nn.Embedding(num_embeddings=10, embedding_dim=20)
                self.embed_sim = EmbeddingSim(num_embeddings=10, bias=True)

            def forward(self, x):
                return self.embed_sim(self.embed(x), self.embed.weight)

        net = Net()
        print(net)
        x = torch.randint(0, 10, [10, 100]).type(torch.LongTensor)
        # Under random initialization each embedding vector is most similar
        # to itself, so the argmax should recover the input indices.
        y = net(x).argmax(dim=-1)
        batch_size, seq_len = x.size()
        same_count = 0
        for i in range(batch_size):
            for j in range(seq_len):
                if x[i, j] == y[i, j]:
                    same_count += 1
        self.assertGreater(1.0 * same_count / (batch_size * seq_len), 0.99)
        # Also exercise the bias-free constructor branch.
        EmbeddingSim(num_embeddings=10, bias=False)
1 change: 1 addition & 0 deletions torch_embed_sim/__init__.py
@@ -0,0 +1 @@
from .embedding_sim import *
32 changes: 32 additions & 0 deletions torch_embed_sim/embedding_sim.py
@@ -0,0 +1,32 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = ['EmbeddingSim']


class EmbeddingSim(nn.Module):
    """Scores inputs against a given embedding matrix and returns a softmax
    distribution over the vocabulary."""

    def __init__(self, num_embeddings, bias=True):
        super(EmbeddingSim, self).__init__()
        self.num_embeddings = num_embeddings
        if bias:
            self.bias = nn.Parameter(torch.Tensor(num_embeddings))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        if self.bias is not None:
            torch.nn.init.zeros_(self.bias)

    def forward(self, x, weight):
        # Dot product of every input vector with every embedding vector.
        y = x.matmul(weight.transpose(1, 0))
        if self.bias is not None:
            y += self.bias
        # Normalize the scores over the vocabulary dimension.
        return F.softmax(y, dim=-1)

    def extra_repr(self):
        return 'num_embeddings={}, bias={}'.format(
            self.num_embeddings, self.bias is not None,
        )
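Note that `EmbeddingSim` owns no weight matrix of its own: the embedding matrix is passed in at call time, which is what lets the output projection share (tie) its weights with the input `nn.Embedding`. A minimal sketch of the equivalence, assuming the module as defined above:

```python
import torch
import torch.nn.functional as F
from torch_embed_sim import EmbeddingSim

embed = torch.nn.Embedding(num_embeddings=10, embedding_dim=20)
sim = EmbeddingSim(num_embeddings=10, bias=False)

x = embed(torch.randint(0, 10, (3, 7)))
# With bias disabled the module reduces to softmax(x @ W^T, dim=-1).
expected = F.softmax(x.matmul(embed.weight.t()), dim=-1)
assert torch.allclose(sim(x, embed.weight), expected)
```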
