# Feature Embedding Prototype

In [1]:
from torch import nn
import torch

In [24]:
class Embedding(nn.Module):
    """Embedding lookup with a configurable out-of-vocabulary (OOV) rule.

    An index counts as OOV when it is ``>= num_embeddings``.  Supported
    values of ``oov_rule``:

    * ``'dummy'``  - indices go to ``nn.Embedding`` unchanged; the caller is
      expected to have reserved a row for unknown values already (i.e. to
      have added +1 to ``num_embeddings``).
    * ``'zero'``   - OOV positions receive an all-zero vector.
    * ``'mean'``   - OOV positions receive the mean over all rows of the
      embedding weight (detached from autograd).
    * ``'random'`` - OOV positions receive standard-normal samples.
    * ``'dist'``   - OOV positions receive the user-supplied ``dist`` vector.

    Args:
        num_embeddings: Number of rows in the embedding table.
        embedding_dim: Size of each embedding vector.
        oov_rule: One of the rule names listed above.
        dist: Replacement vector used by the ``'dist'`` rule; must be
            broadcastable to ``(embedding_dim,)``.
    """

    _REPLACEMENT_RULES = ('zero', 'mean', 'random', 'dist')

    def __init__(self, num_embeddings, embedding_dim, oov_rule="dummy", dist=None):
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.oov_rule = oov_rule  # 'zero', 'mean', 'random', 'dummy', 'dist'
        self.dist = dist

    def forward(self, x):
        """Embed the integer index tensor ``x`` without modifying it.

        Returns:
            Tensor of shape ``x.shape + (embedding_dim,)``.

        Raises:
            ValueError: If ``oov_rule`` is unknown, or if it is ``'dist'``
                and no ``dist`` was supplied.
        """
        rule = self.oov_rule

        if rule == 'dummy':
            try:
                return self.embedding(x)
            except (IndexError, RuntimeError):
                # Out-of-range indices surface as IndexError on CPU and as
                # RuntimeError on some other backends.
                print('ERROR: Expect dummy input to be considered already for embedding_dim (Add +1 to number of embeddings)')
                raise

        # Validate before doing any work so a bad configuration fails fast.
        if rule not in self._REPLACEMENT_RULES:
            raise ValueError('Invalid oov_rule')
        if rule == 'dist' and self.dist is None:
            raise ValueError("oov_rule 'dist' requires the 'dist' argument")

        oov_mask = x >= self.num_embeddings
        where = torch.where(oov_mask)
        # masked_fill is out-of-place: the caller's tensor keeps its original
        # indices.  OOV slots are looked up at row 0 only as a placeholder
        # and are overwritten below.
        y = self.embedding(x.masked_fill(oov_mask, 0))

        if rule == 'zero':
            y[where] = 0.
        elif rule == 'mean':
            mean = self.embedding.weight.detach().mean(dim=0)
            y[where] = mean.to(y.dtype)
        elif rule == 'random':
            # Draw a full-shape sample (as before) but directly on the
            # output's dtype/device so the indexed assignment cannot mix
            # devices.
            values = torch.normal(mean=0., std=1., size=y.shape,
                                  dtype=y.dtype, device=y.device)
            y[where] = values[where]
        else:  # rule == 'dist'
            y[where] = torch.as_tensor(self.dist, dtype=y.dtype, device=y.device)
        return y


class OneHotEmbedding(nn.Module):
    """One-hot encoder with a configurable rule for out-of-range classes.

    A class index counts as out-of-range when it is ``>= num_classes``.
    Supported values of ``rule``:

    * ``'dummy'``      - output has ``num_classes + 1`` columns; out-of-range
      rows are one-hot on the extra last column.
    * ``'zero'``       - out-of-range rows are all zeros.
    * ``'one_over_n'`` - out-of-range rows are filled with ``1 / num_classes``.
    * ``'random'``     - out-of-range rows are uniform samples from ``[0, 1)``.
    * ``'dist'``       - out-of-range rows are set to the supplied ``dist``.

    Args:
        num_classes: Number of known classes.
        rule: One of the rule names listed above.
        dist: Replacement row used by the ``'dist'`` rule; must be
            broadcastable to ``(num_classes,)``.
    """

    _RULES = ('dummy', 'zero', 'one_over_n', 'random', 'dist')

    def __init__(self, num_classes, rule="zero", dist=None):
        super().__init__()
        self.rule = rule  # 'zero', 'one_over_n', 'random', 'dummy', 'dist'
        self.num_classes = num_classes
        self.dist = dist

    def forward(self, x):
        """One-hot encode the integer tensor ``x`` without modifying it.

        Returns:
            Float tensor of shape ``x.shape + (num_classes,)``, or
            ``x.shape + (num_classes + 1,)`` for the ``'dummy'`` rule.

        Raises:
            ValueError: If ``rule`` is unknown, or if it is ``'dist'`` and no
                ``dist`` was supplied.
        """
        rule = self.rule
        # Validate before doing any work so a bad configuration fails fast.
        if rule not in self._RULES:
            raise ValueError('Invalid rule')
        if rule == 'dist' and self.dist is None:
            raise ValueError("rule 'dist' requires the 'dist' argument")

        oov_mask = x >= self.num_classes
        one_hot_fn = torch.nn.functional.one_hot

        if rule == 'dummy':
            # Route out-of-range entries to the extra class index; this is
            # the same result as encoding first and overwriting the rows.
            # masked_fill is out-of-place, so the caller's tensor is kept.
            routed = x.masked_fill(oov_mask, self.num_classes)
            return one_hot_fn(routed, num_classes=self.num_classes + 1).float()

        where = torch.where(oov_mask)
        # Out-of-range entries are encoded as class 0 only as a placeholder;
        # their rows are overwritten below.
        one_hot = one_hot_fn(x.masked_fill(oov_mask, 0),
                             num_classes=self.num_classes).float()

        if rule == 'zero':
            one_hot[where] = 0.
        elif rule == 'one_over_n':
            one_hot[where] = 1 / self.num_classes
        elif rule == 'random':
            # Sample on the output's dtype/device so the indexed assignment
            # cannot mix devices.
            values = torch.rand(one_hot.shape, dtype=one_hot.dtype,
                                device=one_hot.device)
            one_hot[where] = values[where]
        else:  # rule == 'dist'
            one_hot[where] = torch.as_tensor(self.dist, dtype=one_hot.dtype,
                                             device=one_hot.device)
        return one_hot

In [31]:
# Sample modules for the experiments below: a 5-row, 6-dimensional embedding
# table and a 5-class one-hot encoder, both using their default rules.
embedding_t1 = Embedding(num_embeddings=5, embedding_dim=6)
onehot_t1 = OneHotEmbedding(num_classes=5)

In [33]:
# Index 8 is out of range for the 5-class encoder, so this exercises its
# out-of-range rule.  The module is invoked directly rather than through
# .forward() so that any registered nn.Module hooks run as well.
t2 = torch.tensor([8])
onehot_t1(t2)

tensor([[0., 0., 0., 0., 0.]])