In [10]:
import torch
import torch.nn as nn
import numpy as np
import math

In [11]:
# sinusoidal position embedding
class PositionalEmbedding(nn.Module):
    def __init__(self,max_seq_len: int,embed_model_dim: int):
        """
        Args:
            max_seq_len(int) : length of input sequence
            embed_model_dim(int) : dimension of embedding
        """
        super(PositionalEmbedding,self).__init__()
        self.embed_dim = torch.tensor(embed_model_dim).float()
        #self.embed_dim = embed_model_dim

        pe = torch.zeros(max_seq_len,embed_model_dim)
        for pos in range(max_seq_len):
            for i in range(0,embed_model_dim,2):
                pe[pos,i] = torch.sin(torch.tensor(pos/(10000**(2*i/embed_model_dim))))
                pe[pos,i+1] = torch.cos(torch.tensor(pos/(10000**(2*i/embed_model_dim))))
        pe = pe.unsqueeze(0)
        self.register_buffer('pe',pe)

    def forward(self,x):
        """
        Args:
            x(torch.Tensor) : input tensor (B,Length,Dim)
        Returns:
            torch.Tensor : input tensor + positional embedding
        """
        x = x + torch.sqrt(self.embed_dim) # make embeddings relatively larger
        #x = x + math.sqrt(self.embed_dim)
        seq_len = x.size(1)
        x = x + torch.autograd.Variable(self.pe[:,:seq_len],requires_grad=False)
        return x

test_input = torch.zeros(1,3,10)
print(test_input)
posembed = PositionalEmbedding(3,10)
test_output = posembed(test_input)
print(test_output)

tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]])
tensor([[[3.1623, 4.1623, 3.1623, 4.1623, 3.1623, 4.1623, 3.1623, 4.1623,
          3.1623, 4.1623],
         [4.0037, 3.7026, 3.1874, 4.1620, 3.1629, 4.1623, 3.1623, 4.1623,
          3.1623, 4.1623],
         [4.0716, 2.7461, 3.2125, 4.1610, 3.1635, 4.1623, 3.1623, 4.1623,
          3.1623, 4.1623]]])


In [12]:
class MultiHeadSelfAttention(nn.Module):
    def __init__(self,dim: int,num_heads :int=8,qkv_bias: bool=True,dropout: float=0.,
                 is_causal: bool=False,quiet_attention: bool=False):
        """
        Args:
            dim (int): 埋め込み次元数
            num_heads (int): MultiHeadAttentionのHead数
            qkv_bias (bool): MultiHeadAttentionの埋め込み全結合層にbiasを付けるかどうか
            dropout (float): ドロップアウト確率
            is_causal (bool): Trueの場合、masked multi-head attentionを行う
            quiet_attention (bool): Trueの場合、softmaxの分母に1を足す
        Note:
            quiet attentionのreference
            https://www.evanmiller.org/attention-is-off-by-one.html
        """
        super().__init__()
        
        self.is_causal = is_causal
        self.quiet_attention = quiet_attention
        self.num_heads = num_heads
        assert dim % num_heads == 0, f"The hidden size {dim} is not a multiple of the number of head attention"
        self.hidden_dim = dim
        self.head_dim = dim // num_heads
        
        self.query = nn.Linear(dim,dim,bias=qkv_bias)
        self.key = nn.Linear(dim,dim,bias=qkv_bias)
        self.value = nn.Linear(dim,dim,bias=qkv_bias)
        
        self.dropout = nn.Dropout(p=dropout)
        self.projection = nn.Sequential(
            nn.Linear(dim,dim),
            nn.Dropout(p=dropout),
        )
    
    def forward(self,x,mask=False):
        """
        Args:
            x (torch.Tensor): input tensor (B,Length,Dim)
            mask (bool): Trueの場合、masked multi-head attentionを行う
        """
        batch_size,num_patches,_ = x.size()
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)
        
        # マルチヘッドに分割
        #multihead_qkv_shape = q.size()[:-1] + (self.num_heads, self.head_dim)
        multihead_qkv_shape = torch.Size([batch_size, num_patches, self.num_heads, self.head_dim])
        qs = q.view(multihead_qkv_shape)
        qs = qs.permute(0, 2, 1, 3)
        ks = k.view(multihead_qkv_shape)
        ks = ks.permute(0, 2, 1, 3)
        ks_T = ks.transpose(2,3)
        vs = v.view(multihead_qkv_shape)
        vs = vs.permute(0, 2, 1, 3)
        
        scaled_dot_product = qs@ks_T / np.sqrt(self.head_dim)

        # masked multi-head attention
        if self.is_causal:
            mask = nn.Transformer.generate_square_subsequent_mask(num_patches,device=x.device)
            scaled_dot_product = scaled_dot_product + mask

        if self.quiet_attention:
            self_attention = _softmax_one(scaled_dot_product,dim=-1)
        else:
            self_attention = nn.functional.softmax(scaled_dot_product,dim=-1)
        self_attention = self.dropout(self_attention) # 実装上はあるっぽいけど何なんこれ
        
        context_layer = self_attention@vs
        #context_layer = context_layer.transpose(1,2).reshape(batch_size,num_patchs,self.hidden_dim)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous().reshape(batch_size,num_patches,self.hidden_dim)
        out = self.projection(context_layer)
        #out = context_layer
        
        return out

def _softmax_one(x,dim=-1):
    """ https://www.evanmiller.org/attention-is-off-by-one.html の実装
    Args:
        x (torch.Tensor):
        dim (int, optional): softmaxを取る次元. Defaults to -1.
    Returns:
        torch.Tensor: softmaxを取った後のテンソル
    """
    x = x - x.max(dim=dim, keepdim=True).values # subtract the max for stability
    exp_x = torch.exp(x)
    return exp_x / (1+exp_x.sum(dim=dim,keepdim=True))

class FeedForward(nn.Module):
    def __init__(self,dim: int,hidden_dim: int=768*4,activation=nn.GELU(),dropout: float=0.):
        """
        Args:
            dim (int): 埋め込み次元数
            hidden_dim (int): FeedForward Networkの隠れ層次元数
            activation (torch.nn.modules.activation): pytorchの活性化関数
            dropout (float): ドロップアウト確率
        """
        super().__init__()
        self.linear1 = nn.Linear(dim,hidden_dim)
        self.linear2 = nn.Linear(hidden_dim,dim)
        self.activation = activation
        self.dropout = nn.Dropout(p=dropout)
        
    def forward(self,x):
        x = self.linear1(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.linear2(x)
        x = self.dropout(x)
        
        return x

In [13]:
x = torch.tensor([[
    [1,1,1,1,1,1,1,1,1,1],
    [1,1,1,1,1,1,1,1,1,1],
    [2,2,2,2,2,2,2,2,2,2],
    [2,2,2,2,2,2,2,2,2,2],
]]).float()

batch_size,num_patches,_ = x.size()
query = nn.Linear(10,10)
key = nn.Linear(10,10)
value = nn.Linear(10,10)

q = query(x)
k = key(x)
v = value(x)

num_heads = 2
head_dim = 10 // num_heads

# マルチヘッドに分割
#multihead_qkv_shape = q.size()[:-1] + (self.num_heads, self.head_dim)
multihead_qkv_shape = torch.Size([1, 4, num_heads, head_dim])
qs = q.view(multihead_qkv_shape) # (b, n_patch, n_heads, head_dim)
qs = qs.permute(0, 2, 1, 3)
ks = k.view(multihead_qkv_shape)
ks = ks.permute(0, 2, 1, 3)
ks_T = ks.transpose(2,3)
vs = v.view(multihead_qkv_shape)
vs = vs.permute(0, 2, 1, 3)

scaled_dot_product = qs@ks_T / np.sqrt(head_dim) # (b, n_heads, n_patch, n_patch)
print(scaled_dot_product.shape)

# masked multi-head attention
if True:
    mask = nn.Transformer.generate_square_subsequent_mask(num_patches,device=x.device)
    print(mask.shape)
    scaled_dot_product = scaled_dot_product + mask
    print(scaled_dot_product)

self_attention = nn.functional.softmax(scaled_dot_product,dim=-1)
print(self_attention.shape)
print(self_attention)

context_layer = self_attention@vs
#context_layer = context_layer.transpose(1,2).reshape(batch_size,num_patchs,self.hidden_dim)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous().reshape(batch_size,num_patches,10)
#out = self.projection(context_layer)
out = context_layer

print(out)

torch.Size([1, 2, 4, 4])
torch.Size([4, 4])
tensor([[[[ 0.5834,    -inf,    -inf,    -inf],
          [ 0.5834,  0.5834,    -inf,    -inf],
          [ 1.3026,  1.3026,  2.5177,    -inf],
          [ 1.3026,  1.3026,  2.5177,  2.5177]],

         [[-0.2067,    -inf,    -inf,    -inf],
          [-0.2067, -0.2067,    -inf,    -inf],
          [-0.3710, -0.3710, -0.4832,    -inf],
          [-0.3710, -0.3710, -0.4832, -0.4832]]]], grad_fn=<AddBackward0>)
torch.Size([1, 2, 4, 4])
tensor([[[[1.0000, 0.0000, 0.0000, 0.0000],
          [0.5000, 0.5000, 0.0000, 0.0000],
          [0.1862, 0.1862, 0.6276, 0.0000],
          [0.1144, 0.1144, 0.3856, 0.3856]],

         [[1.0000, 0.0000, 0.0000, 0.0000],
          [0.5000, 0.5000, 0.0000, 0.0000],
          [0.3456, 0.3456, 0.3089, 0.0000],
          [0.2640, 0.2640, 0.2360, 0.2360]]]], grad_fn=<SoftmaxBackward0>)
tensor([[[ 0.0877, -0.1831,  0.2694, -0.2178, -0.2970, -0.6170, -1.1308,
           1.1616, -0.5541, -0.4736],
         [ 0.0877, -0.

In [14]:
class TransformerEncoderBlock(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int, ff_hidden_dim: int, dropout: float):
        """
        Args:
            embed_dim (int): 埋め込み次元数
            num_heads (int): MultiHeadAttentionのHead数
            ff_hidden_dim (int): FeedForward Networkの隠れ層次元数
            dropout (float): ドロップアウト確率
        """
        super().__init__()
        self.mhsa = MultiHeadSelfAttention(dim=embed_dim,num_heads=num_heads,dropout=dropout)
        self.ff = FeedForward(dim=embed_dim,hidden_dim=ff_hidden_dim,dropout=dropout)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
    
    def forward(self,x):
        z = self.mhsa(x)
        z = self.norm1(z)
        x = x + z
        z = self.ff(x)
        z = self.norm2(z)
        x = x + z
        return x

class TransformerEncoder(nn.Module):
    def __init__(self,max_seq_len: int, vocab_size: int, embed_dim: int, 
                 ff_hidden_dim: int, num_blocks: int, num_heads: int, dropout: float):
        """
        Args:
            max_seq_len (int): 入力系列の最大長
            vocab_size (int): 語彙数
            embed_dim (int): 埋め込み次元数
            ff_hidden_dim (int): FeedForward Networkの隠れ層次元数
            num_blocks (int): TransformerBlockの数
            num_heads (int): MultiHeadAttentionのHead数
            dropout (float): ドロップアウト確率
        """
        super(TransformerEncoder,self).__init__()
        
        self.embedding_layer = nn.Embedding(vocab_size,embed_dim)
        self.positional_embedding = PositionalEmbedding(max_seq_len,embed_dim)

        self.blocks = nn.ModuleList([
            TransformerEncoderBlock(embed_dim=embed_dim,num_heads=num_heads,ff_hidden_dim=ff_hidden_dim,dropout=dropout)
            for _ in range(num_blocks)
        ])
    
    def forward(self,token_ids):
        """
        Args:
            token_ids (torch.Tensor): 入力トークンID (batch_size,seq_len)
        Returns:
            torch.Tensor: TransformerEncoderの出力 (batch_size,seq_len,embed_dim)
        """
        x = self.embedding_layer(token_ids)
        x = self.positional_embedding(x)
        for block in self.blocks:
            x = block(x)
        return x
    
test_input = torch.tensor([[1,2,3,4,5,6,7,8,9,10]])
print(test_input.shape)
transformer_encoder = TransformerEncoder(max_seq_len=10,vocab_size=30,embed_dim=10,ff_hidden_dim=40,num_blocks=2,num_heads=2,dropout=0.1)
test_output = transformer_encoder(test_input)
print(test_output.shape)

torch.Size([1, 10])
torch.Size([1, 10, 10])


In [15]:
class MultiHeadCrossAttention(nn.Module):
    def __init__(self,dim: int,num_heads: int=8,dropout: float=0.,quiet_attention: bool=False, proj: bool=False):
        """
        Args:
            dim (int): 埋め込み次元数
            num_heads (int): MultiHeadAttentionのHead数
            dropout (float): ドロップアウト確率
            quiet_attention (bool): Trueの場合、softmaxの分母に1を足す
        """
        super().__init__()
        
        self.quiet_attention = quiet_attention
        self.num_heads = num_heads
        assert dim % num_heads == 0, f"The hidden size {dim} is not a multiple of the number of head attention"
        self.hidden_dim = dim
        self.head_dim = dim // num_heads
        
        self.dropout = nn.Dropout(p=dropout)
        if proj:
            self.projection = nn.Sequential(
                nn.Linear(dim,dim),
                nn.Dropout(p=dropout),
            )
        else:
            self.projection = nn.Sequential(
                nn.Identity(),
                nn.Dropout(p=dropout),
            )
    
    def forward(self,query,key,value):
        """ 
        Args:
            query (torch.Tensor): query (batch_size,query_len,hidden_dim)
            key (torch.Tensor): key (batch_size,key_len,hidden_dim)
            value (torch.Tensor): value (batch_size,value_len,hidden_dim)
        """ 
        # マルチヘッドに分割
        batch_size,query_len,_ = query.size()
        batch_size,key_len,_ = key.size()
        batch_size,value_len,_ = value.size()

        multihead_q_shape = torch.Size([batch_size, query_len, self.num_heads, self.head_dim])
        multihead_k_shape = torch.Size([batch_size, key_len, self.num_heads, self.head_dim])
        multihead_v_shape = torch.Size([batch_size, value_len, self.num_heads, self.head_dim])

        qs = query.view(multihead_q_shape)
        qs = qs.permute(0, 2, 1, 3)
        ks = key.view(multihead_k_shape)
        ks = ks.permute(0, 2, 1, 3)
        ks_T = ks.transpose(2,3)
        vs = value.view(multihead_v_shape)
        vs = vs.permute(0, 2, 1, 3)

        scaled_dot_product = qs@ks_T / np.sqrt(self.head_dim)
        if self.quiet_attention:
            cross_attention = _softmax_one(scaled_dot_product,dim=-1)
        else:
            cross_attention = nn.functional.softmax(scaled_dot_product,dim=-1)
        
        context_layer = cross_attention@vs
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous().reshape(batch_size,query_len,self.hidden_dim)
        out = self.projection(context_layer)

        return out

In [16]:
# causal attention test 
test_input = torch.tensor([[
    [0,0,0,0,0],
    [0,0,0,0,0],
    [1,1,1,1,1],
    [1,1,1,1,1],
]]).float()
#print(test_input.shape)

# normal attention
attn = MultiHeadSelfAttention(dim=5,num_heads=1,dropout=0.0,quiet_attention=False)
test_output = attn(test_input)
print(test_output)

# causal attention
attn = MultiHeadSelfAttention(dim=5,num_heads=1,dropout=0.0,is_causal=True,quiet_attention=False)
test_output = attn(test_input)
print(test_output)

tensor([[[ 0.4126, -0.1694, -0.3108,  0.4688,  0.2641],
         [ 0.4126, -0.1694, -0.3108,  0.4688,  0.2641],
         [ 0.4031, -0.1579, -0.3032,  0.4626,  0.2592],
         [ 0.4031, -0.1579, -0.3032,  0.4626,  0.2592]]],
       grad_fn=<ViewBackward0>)
tensor([[[-0.2384,  0.0984,  0.4646,  0.5792,  0.3071],
         [-0.2384,  0.0984,  0.4646,  0.5792,  0.3071],
         [-0.1494,  0.1263,  0.5072,  0.4239,  0.2393],
         [-0.1059,  0.1399,  0.5280,  0.3479,  0.2061]]],
       grad_fn=<ViewBackward0>)


In [17]:
class TransformerDecoderBlock(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int, ff_hidden_dim: int, dropout: float):
        """
        Args:
            embed_dim (int): 埋め込み次元数
            num_heads (int): MultiHeadAttentionのHead数
            ff_hidden_dim (int): FeedForward Networkの隠れ層次元数
            dropout (float): ドロップアウト確率
        """
        super(TransformerDecoderBlock,self).__init__()

        self.mmhsa = MultiHeadSelfAttention(dim=embed_dim,num_heads=num_heads,dropout=dropout,is_causal=True)
        self.mhca = MultiHeadCrossAttention(dim=embed_dim,num_heads=num_heads,dropout=dropout)
        self.ff = FeedForward(dim=embed_dim,hidden_dim=ff_hidden_dim,dropout=dropout)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.norm3 = nn.LayerNorm(embed_dim)

    def forward(self,x,encoder_output):
        """
        Args:
            x (torch.Tensor): 入力トークン (batch_size,seq_len)
            encoder_output (torch.Tensor): TransformerEncoderの出力 (batch_size,seq_len,embed_dim)
        """
        z = self.mmhsa(x)
        z = self.norm1(z)
        x = x + z
        z = self.mhca(query=x,key=encoder_output,value=encoder_output)
        z = self.norm2(z)
        x = x + z
        z = self.ff(x)
        z = self.norm3(z)
        x = x + z
        return x

In [18]:
class TransformerDecoder(nn.Module):
    def __init__(self,max_seq_len: int, vocab_size: int, embed_dim: int, 
                 ff_hidden_dim: int, num_blocks: int, num_heads: int, dropout: float):
        """
        Args:
            max_seq_len (int): 入力系列の最大長
            vocab_size (int): 語彙数
            embed_dim (int): 埋め込み次元数
            ff_hidden_dim (int): FeedForward Networkの隠れ層次元数
            num_blocks (int): TransformerBlockの数
            num_heads (int): MultiHeadAttentionのHead数
            dropout (float): ドロップアウト確率
        """
        super(TransformerDecoder,self).__init__()
        
        self.embedding_layer = nn.Embedding(vocab_size,embed_dim)
        self.positional_embedding = PositionalEmbedding(max_seq_len,embed_dim)
        self.dropout = nn.Dropout(p=dropout)

        self.blocks = nn.ModuleList([
            TransformerDecoderBlock(embed_dim=embed_dim,num_heads=num_heads,ff_hidden_dim=ff_hidden_dim,dropout=dropout)
            for _ in range(num_blocks)
        ])

        self.head = nn.Linear(embed_dim,vocab_size)
        self.softmax = nn.Softmax(dim=-1)
    
    def forward(self,x,encoder_output):
        """
        Args:
            x (torch.Tensor): 入力トークン (batch_size,seq_len)
            encoder_output (torch.Tensor): TransformerEncoderの出力 (batch_size,seq_len,embed_dim)
        Returns:
            torch.Tensor: TransformerDecoderの出力 (batch_size,seq_len,vocab_size)
        """
        x = self.embedding_layer(x)
        x = self.positional_embedding(x)
        x = self.dropout(x)
        for block in self.blocks:
            x = block(x,encoder_output)
        x = self.head(x)
        x = self.softmax(x)
        return x

In [19]:
"""
TODO transformerの実装
実装したところで何に使うのかは知らんけど

class Transformer(nn.Module):
    def __init__(self,vocab_size: int, embed_dim: int, max_seq_len: int,
                 num_encoder_blocks: int, num_decoder_blocks: int, embed_dim: int,
                 num_heads: int, ff_hidden_dim: int, dropout: float)
"""

'\nTODO transformerの実装\n実装したところで何に使うのかは知らんけど\n\nclass Transformer(nn.Module):\n    def __init__(self,vocab_size: int, embed_dim: int, max_seq_len: int,\n                 num_encoder_blocks: int, num_decoder_blocks: int, embed_dim: int,\n                 num_heads: int, ff_hidden_dim: int, dropout: float)\n'

In [20]:
class GPTBlocks(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int, ff_hidden_dim: int, dropout: float):
        """
        Args:
            embed_dim (int): 埋め込み次元数
            num_heads (int): MultiHeadAttentionのHead数
            ff_hidden_dim (int): FeedForward Networkの隠れ層次元数
            dropout (float): ドロップアウト確率
        """
        super(GPTBlocks,self).__init__()

        self.mmhsa = MultiHeadSelfAttention(dim=embed_dim,num_heads=num_heads,dropout=dropout,is_causal=True)
        self.ff = FeedForward(dim=embed_dim,hidden_dim=ff_hidden_dim,dropout=dropout)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(self,x):
        z = self.mmhsa(x)
        z = self.norm1(z)
        x = x + z
        z = self.ff(x)
        z = self.norm2(z)
        x = x + z
        return x

class GPT(nn.Module):
    def __init__(self, max_seq_len: int, vocab_size: int, num_blocks: int,
                 embed_dim: int, num_heads: int, ff_hidden_dim: int, dropout: float):
        """
        Args:
            max_seq_len (int): 入力系列の最大長
            vocab_size (int): 語彙数
            embed_dim (int): 埋め込み次元数
            num_blocks (int): TransformerBlockの数
            num_heads (int): MultiHeadAttentionのHead数
            ff_hidden_dim (int): FeedForward Networkの隠れ層次元数
            dropout (float): ドロップアウト確率
        """
        super(GPT,self).__init__()

        self.embedding_layer = nn.Embedding(vocab_size,embed_dim)
        self.positional_embedding = PositionalEmbedding(max_seq_len,embed_dim)
        self.dropout = nn.Dropout(p=dropout)

        self.blocks = nn.ModuleList([
            GPTBlocks(embed_dim=embed_dim,num_heads=num_heads,ff_hidden_dim=ff_hidden_dim,dropout=dropout)
            for _ in range(num_blocks)
        ])
        
        self.head = nn.Linear(embed_dim,vocab_size) # embeddingの逆行列を使う方法もある
        self.softmax = nn.Softmax(dim=-1)
    
    def forward(self,x,target=None):
        """
        Args:
            x (torch.Tensor): 入力トークン (batch_size,seq_len)
            target (torch.Tensor): 教師トークン (batch_size,seq_len)
        Returns:
            tuple[torch.Tensor, torch.Tensor]: GPTの出力 (batch_size,seq_len,vocab_size), 損失 (1,)
        """
        x = self.embedding_layer(x)
        x = self.positional_embedding(x)
        x = self.dropout(x)
        for block in self.blocks:
            x = block(x)
        x = self.head(x)
        x = self.softmax(x)

        if target is not None:
            loss = nn.functional.cross_entropy(x.view(-1,x.size(-1)),target.view(-1))
            return x,loss
        else:
            return x,None
        


In [21]:
# GPT test
test_input = torch.tensor([[1,2,3,4,5,6,7,8,9,10]])
print(test_input.shape)
gpt = GPT(max_seq_len=10,vocab_size=30,num_blocks=2,embed_dim=10,num_heads=2,ff_hidden_dim=40,dropout=0.1)
test_output = gpt(test_input)
print(test_output.shape)
print(test_output.argmax(dim=-1))

torch.Size([1, 10])
torch.Size([1, 10, 30])
tensor([[20, 20,  4, 20,  4, 20, 20, 20, 20, 20]])


In [22]:
# 別の用事のやつ
import torch
import torch.nn as nn

embed = nn.Embedding(5,4)
test_input = torch.tensor([[0,1,2,3,4]])
test_output = embed(test_input)
print(test_output.shape)
print(test_output)

# embedの疑似逆行列をかけて元の値に戻す
weight = embed.weight.data.detach()
weight_inv = (weight.T@weight).inverse()@weight.T
print(weight_inv.shape)

test = weight_inv@test_output#@weight_inv

print(test)

torch.Size([1, 5, 4])
tensor([[[ 0.8219, -0.9072, -0.7282,  2.3748],
         [ 1.0498,  0.5476, -0.4132,  0.2733],
         [-0.0970,  0.2545, -1.4792, -0.9674],
         [ 0.9543,  1.2970,  0.2506, -1.0072],
         [ 0.0949,  1.2535,  0.2068,  0.5078]]], grad_fn=<EmbeddingBackward0>)
torch.Size([4, 5])
tensor([[[ 1.0000e+00, -9.2654e-08, -5.3197e-08,  3.7330e-07],
         [-5.3914e-08,  1.0000e+00,  2.1766e-08, -2.6786e-07],
         [-1.5445e-08, -4.7437e-08,  1.0000e+00,  6.1151e-08],
         [-5.1538e-08,  2.5637e-08,  4.2161e-09,  1.0000e+00]]],
       grad_fn=<CloneBackward0>)
