In [3]:
import numpy as np
from backpropagation import Matrix, zero_grad


def forward_pass(context, label_word, context_weight, softmax_weight, d, method="cbow"):
    """Build the word2vec loss graph for one minibatch.

    Args:
        context: non-empty list of Matrix objects, one per context position;
            each presumably holds a ``batch_size x |V|`` batch of one-hot rows
            (TODO confirm against the caller).
        label_word: Matrix of label (centre) words, presumably ``batch_size x |V|``.
        context_weight: input embedding Matrix (presumably ``|V| x d``).
        softmax_weight: output projection Matrix (presumably ``d x |V|``).
        d: embedding dimension. The computation does not need it (shapes come
            from the weight matrices); it is kept so existing callers still work.
        method: ``"cbow"`` averages the context embeddings and predicts
            ``label_word``; ``"skipgram"`` embeds ``label_word`` and predicts
            each context word, averaging the per-word losses.

    Returns:
        The loss node produced by ``cross_entropy`` (``train`` calls
        ``.backprop()`` on it).

    Raises:
        ValueError: if ``context`` is empty or ``method`` is not recognised.
    """
    if not context:
        raise ValueError("context must contain at least one word")

    if method == "cbow":
        # Mean of the context-word embeddings. Starting from the first term avoids
        # building a zeros Matrix (and the shape bookkeeping that requires).
        net = context[0] @ context_weight
        for word in context[1:]:
            net = net + word @ context_weight
        net = net / len(context)
        out = (net @ softmax_weight).softmax()
        return out.cross_entropy(label_word)

    if method == "skipgram":
        # The prediction depends only on the centre word, so compute it once and
        # score it against every context word as a separate target.
        net = label_word @ context_weight
        out = (net @ softmax_weight).softmax()
        loss = out.cross_entropy(context[0])
        for word in context[1:]:
            loss = loss + out.cross_entropy(word)
        return loss / len(context)

    # Fail loudly instead of returning a plain int that breaks .backprop() later.
    raise ValueError(f"Not a valid method: {method!r} (expected 'cbow' or 'skipgram')")


def train(data, weights, batch_size=4, num_epochs=100, lr=0.01, dim=100, method="cbow"):
    """Train word-embedding weights with minibatch gradient descent.

    Args:
        data: array of one-hot rows shaped ``N x C x |V|``. Along axis 1 the
            first ``C - 1`` entries are context words and the last entry is the
            label (centre) word.
        weights: ``[context_weight, softmax_weight]`` Matrix objects, passed to
            ``forward_pass`` as ``weights[0]`` and ``weights[1]``. Presumably
            shaped ``|V| x dim`` and ``dim x |V|`` (e.g. from
            ``np.random.randn``) — TODO confirm with the caller.
        batch_size: samples per update step. The final batch may be smaller.
        num_epochs: number of full passes over ``data``.
        lr: learning rate.
        dim: embedding dimension, forwarded to ``forward_pass``.
        method: ``"cbow"`` or ``"skipgram"``, forwarded to ``forward_pass``.

    Prints the mean batch loss once per epoch.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    num_context = data.shape[1] - 1
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        num_batches = 0
        for start in range(0, data.shape[0], batch_size):
            # Slice end is exclusive, so this takes exactly batch_size rows.
            batch = data[start:start + batch_size]
            # batch[:, w, :] is already 2-D (B x |V|); squeezing axis 1 would fail.
            context = [Matrix(batch[:, w, :]) for w in range(num_context)]
            label_word = Matrix(batch[:, -1, :])

            loss = forward_pass(context, label_word, weights[0], weights[1], dim, method=method)
            order = loss.backprop()
            # NOTE(review): relies on Matrix implementing in-place ``-=`` (__isub__);
            # otherwise this only rebinds the loop variable and the weight is unchanged.
            for weight in weights:
                weight -= lr * weight.grad
            zero_grad(order)

            epoch_loss += float(loss.val)
            num_batches += 1

        if num_batches:
            print(f"Epoch: {epoch+1} | Loss: {epoch_loss / num_batches:.5f}")

In [5]:
def word_to_onehot(vocab, word):
    """Encode ``word`` as a 1-D float32 one-hot vector of length ``len(vocab)``.

    ``vocab`` must support ``.index`` (e.g. a list). For a list, a word that is
    not in the vocabulary raises ``ValueError`` from ``vocab.index``.
    """
    vector = np.zeros(len(vocab), dtype=np.float32)
    hot_position = vocab.index(word)
    vector[hot_position] = 1.0
    return vector

def onehot_to_word(vocab, onehot):
    """Return the vocabulary entry at the position of the largest value in ``onehot``.

    Ties resolve to the earliest position (``np.argmax`` semantics), so this
    also decodes non-one-hot score vectors such as softmax outputs.
    """
    hot_position = np.argmax(onehot)
    return vocab[hot_position]

