In [1]:
import os
import warnings
# warnings.filterwarnings("ignore")
import torch
import torch.nn as nn
import pandas as pd
import numpy as np

In [None]:
# Three level-indices looked up in one table of 10 + 5 + 2 = 17 rows, 3 values per row.
x = torch.as_tensor([9, 4, 0])
embd = nn.Embedding(num_embeddings=10 + 5 + 2, embedding_dim=3)
embd(x)

In [None]:
# Method 1
# x holds several level-indices of ONE feature. For example, if they are all tokens, the three
# indices stand for tokens[9], tokens[4] and tokens[0] with vocab_size=17.
# The one-hot representation of x is then:
x_oh = nn.functional.one_hot(x, num_classes=17)
x_oh

In [None]:
# Method 2
# x holds level-indices of DIFFERENT features, e.g. user_id[9], item_id[4] and content_id[0],
# whose level sizes are (10, 5, 2). The concatenated one-hot representation of x is then
# the per-feature one-hot vectors laid end to end (17 entries in total):
torch.cat([
    nn.functional.one_hot(torch.tensor(9), num_classes=10),  # user_id
    nn.functional.one_hot(torch.tensor(4), num_classes=5),   # item_id
    nn.functional.one_hot(torch.tensor(0), num_classes=2),   # content_id
])

In [None]:
# Corresponding single-table indices: 9, 10 + 4 = 14 and 10 + 5 + 0 = 15.
# Summing the rows of their 17-wide one-hot encoding gives the method-2 vector.
y = torch.as_tensor([9, 14, 15])
nn.functional.one_hot(y, num_classes=17)

In [None]:
# Look up the offset indices y = [9, 14, 15] in the shared 17-row table `embd`:
# one 3-dim row per feature.
embd(y)

In [None]:
# Raw per-feature indices, and the number of levels of each feature.
x, num_classes = torch.as_tensor([9, 4, 0]), torch.as_tensor([10, 5, 2])

In [None]:
# Start offset of each feature inside the shared table = exclusive prefix sum of the
# level sizes. Subtracting num_classes from its inclusive cumsum stays in integer
# arithmetic; concatenating with a float32 zero would round sums above 2**24.
torch.cumsum(num_classes, dim=0) - num_classes

In [2]:
def offset_multifeatures(input_tensor, num_classes):
    """Shift per-feature level indices into one shared index space.

    Feature ``i`` owns the contiguous index range
    ``[sum(num_classes[:i]), sum(num_classes[:i + 1]))``, so the shifted indices of
    different features never collide and can be looked up in a single table.

    Args:
        input_tensor: integer tensor whose last dimension enumerates the features;
            entry ``[..., i]`` is a level index of feature ``i``.
        num_classes: 1-D tensor (or sequence) with the number of levels of each
            feature. It is placed on ``input_tensor``'s device.

    Returns:
        Tensor with the shape and dtype of ``input_tensor`` where
        ``out[..., i] = input_tensor[..., i] + sum(num_classes[:i])``.

    Raises:
        AssertionError: if the number of features disagrees with ``num_classes``, or
            any index lies outside ``[0, num_classes[i])``. These checks are
            ``assert`` statements, so they are skipped under ``python -O``.
    """
    num_classes = torch.as_tensor(num_classes, device=input_tensor.device)
    assert len(num_classes) == input_tensor.shape[-1], 'every feature must have its num_class'
    # A negative index would be shifted into the range owned by a previous feature.
    assert torch.all(input_tensor >= 0), 'index number must be non-negative'
    assert torch.all(input_tensor < num_classes), 'index number exceeds or be equal to num_classes. Index number must be smaller than corresponding num_class'
    # Exclusive prefix sum, computed without leaving the integer dtype or the device of
    # num_classes: a detour through a float32 zero rounds offsets above 2**24.
    offsets = torch.cumsum(num_classes, dim=0) - num_classes
    return (input_tensor + offsets).type(input_tensor.dtype)

In [3]:
# Batch of two identical samples; each sample is 5 rows of 3 feature indices.
x = torch.tensor([[9, 3, 4],
                  [0, 1, 0],
                  [7, 0, 0],
                  [0, 0, 0],
                  [1, 1, 1]]).repeat(2, 1, 1)

In [4]:
from torch import Tensor
class MultiIndexEmbedding(nn.Embedding):
    """Embedding table shared by several categorical features.

    Construction is inherited unchanged from ``nn.Embedding``. ``num_embeddings``
    is expected to equal the total number of levels over all features, because
    ``forward`` shifts every feature's indices by the level counts of the features
    before it and then looks them up in this one table.
    """

    def forward(self, input: Tensor, num_classes: Tensor, flatten=True) -> Tensor:
        """Embed per-feature level indices.

        Args:
            input: index tensor whose last dimension enumerates the features.
            num_classes: 1-D tensor with the number of levels of each feature.
            flatten: when true, merge the feature and embedding dimensions.

        Returns:
            Shape ``(..., n_features * embedding_dim)`` when ``flatten`` is true,
            otherwise ``(..., n_features, embedding_dim)``.
        """
        input_ = offset_multifeatures(input, num_classes)
        expected_rows = int(num_classes.sum())
        if self.num_embeddings != expected_rows:
            # Only warn. Overwriting self.num_embeddings would not resize self.weight,
            # leaving the module's metadata inconsistent with its actual table.
            warnings.warn(
                f'num_embeddings={self.num_embeddings} does not match the sum of num_classes '
                f'({expected_rows}); the embedding table is not resized, so indices beyond it will fail'
            )
        embed_ = super().forward(input_)
        return embed_.flatten(start_dim=-2) if flatten else embed_

In [7]:
# Three features with 10, 5 and 6 levels share one table of 10 + 5 + 6 = 21 rows.
num_classes = torch.as_tensor([10, 5, 6])
embed = MultiIndexEmbedding(num_embeddings=21, embedding_dim=3)
embed(x, num_classes=num_classes)

tensor([[[[ 0.4871,  1.2604, -0.6868],
          [-0.0262,  0.9785,  0.4463],
          [ 1.4684,  0.3059, -0.2481]],

         [[ 0.2354,  2.5292,  2.2188],
          [ 0.1607,  0.8775, -1.2654],
          [-0.1132, -1.3775, -0.2470]],

         [[ 0.3228,  0.3680, -0.0678],
          [ 0.8700,  0.4354, -0.1996],
          [-0.1132, -1.3775, -0.2470]],

         [[ 0.2354,  2.5292,  2.2188],
          [ 0.8700,  0.4354, -0.1996],
          [-0.1132, -1.3775, -0.2470]],

         [[ 0.9199,  0.5815,  1.5986],
          [ 0.1607,  0.8775, -1.2654],
          [-0.5123, -0.7143, -1.2022]]],


        [[[ 0.4871,  1.2604, -0.6868],
          [-0.0262,  0.9785,  0.4463],
          [ 1.4684,  0.3059, -0.2481]],

         [[ 0.2354,  2.5292,  2.2188],
          [ 0.1607,  0.8775, -1.2654],
          [-0.1132, -1.3775, -0.2470]],

         [[ 0.3228,  0.3680, -0.0678],
          [ 0.8700,  0.4354, -0.1996],
          [-0.1132, -1.3775, -0.2470]],

         [[ 0.2354,  2.5292,  2.2188],
       