In [1]:
import numpy as np
import matplotlib.pyplot as plt

# Embedding layerの実装

In [7]:
class Embedding:
    """Lookup layer that selects rows of a weight matrix by index.

    Attributes:
        params: one-element list holding the weight array ``W``.
        grads: one-element list holding the gradient buffer, allocated with
            the same shape and dtype as ``W``.
        idx: the indices passed to the most recent ``forward`` call
            (``None`` until ``forward`` has been called).
    """

    def __init__(self, W):
        """Store ``W`` and allocate a zero-filled gradient buffer for it."""
        self.params = [W]
        self.grads = [np.zeros_like(W)]
        self.idx = None

    def forward(self, idx):
        """Return ``W[idx]`` and remember ``idx`` for the backward pass.

        Args:
            idx: index (or array of indices) into the first axis of ``W``.

        Returns:
            The selected rows of ``W``.
        """
        W, = self.params
        self.idx = idx
        return W[idx]

    def backward(self, dout):
        """Accumulate ``dout`` into the gradient rows selected in ``forward``.

        The gradient buffer is updated in place (so any external reference to
        ``self.grads[0]`` sees the new values): it is zeroed first, then each
        ``dout[i]`` is added to the row ``idx[i]``.

        Args:
            dout: upstream gradient; ``dout[i]`` belongs to ``idx[i]``.

        Returns:
            None. Nothing is propagated further back from this layer.

        Raises:
            RuntimeError: if called before ``forward``.
        """
        if self.idx is None:
            # Without this guard ``np.add.at`` would treat None as a new axis
            # and silently add ``dout`` to the whole buffer.
            raise RuntimeError("Embedding.backward() called before forward()")

        dW, = self.grads
        dW[...] = 0  # in-place reset keeps the buffer object shared

        # Unbuffered add: repeated indices are accumulated rather than
        # overwritten, and every index is processed (the previous loop
        # returned after the first iteration).
        np.add.at(dW, self.idx, dout)
        return None
            

In [6]:
# Mutability check: unpacking a one-element list binds the name to the very
# same ndarray object that the list holds, so an in-place write through that
# name is also visible through the list.
list_a = [np.array([1, 2])]
(a,) = list_a

# Overwrite every element in place (no new array is created).
a[...] = 0

list_a

[array([0, 0])]

In [8]:
class EmbeddingDot:
    """Embedding lookup followed by a row-wise dot product.

    Attributes:
        embed: the wrapped ``Embedding`` layer built from ``W``.
        params: the wrapped layer's parameter list (same list object).
        grads: the wrapped layer's gradient list (same list object).
        cache: ``(h, target_W)`` saved by ``forward`` for use in ``backward``.
    """

    def __init__(self, W):
        """Wrap ``W`` in an ``Embedding`` layer and expose its params/grads."""
        self.embed = Embedding(W)
        self.params = self.embed.params
        self.grads = self.embed.grads
        self.cache = None

    def forward(self, h, idx):
        """Dot each row of ``h`` with the embedding row selected by ``idx``.

        Args:
            h: array multiplied element-wise with the selected rows
                (presumably shaped (batch, hidden) -- confirm with callers).
            idx: indices forwarded to the embedding layer.

        Returns:
            ``np.sum(target_W * h, axis=1)``: one score per row.
        """
        target_W = self.embed.forward(idx)
        out = np.sum(target_W * h, axis=1)

        self.cache = (h, target_W)
        return out

    def backward(self, dout):
        """Backpropagate through the dot product.

        Args:
            dout: upstream gradient with ``dout.shape[0]`` total elements
                (it is reshaped to a column so it broadcasts over each row).

        Returns:
            Gradient with respect to ``h`` (``dout * target_W``).
        """
        h, target_W = self.cache
        dout = dout.reshape(dout.shape[0], 1)

        # Gradient w.r.t. the selected embedding rows goes to the inner layer.
        dtarget_W = dout * h
        self.embed.backward(dtarget_W)

        dh = dout * target_W
        return dh