In [1]:
import torch.sparse as sparse
import numpy.random as random
import numpy as np
from torchviz import make_dot
import torch
from d2l import torch as d2l

In [2]:
def synthetic_data(w, b, num_examples):  #@save
    """Generate a synthetic linear-regression dataset ``y = Xw + b + noise``.

    Args:
        w: weight tensor; ``len(w)`` sets the number of feature columns
            (a 1-D tensor in this notebook).
        b: bias (intercept) added to every target.
        num_examples: number of rows to draw.

    Returns:
        ``(X, y)`` where ``X`` is ``(num_examples, len(w))`` with standard-normal
        entries and ``y`` is reshaped to a column ``(-1, 1)``. Targets carry
        Gaussian noise with standard deviation 0.01.
    """
    X = torch.normal(0, 1, (num_examples, len(w)))
    # Apply the bias the signature and docstring promise (it used to be dropped).
    y = torch.matmul(X, w) + b
    y += torch.normal(0, 0.01, y.shape)
    return X, y.reshape((-1, 1))

# Seed both RNGs this notebook draws from (torch for data and weight init,
# numpy.random for minibatch shuffling) so Restart & Run All is reproducible.
torch.manual_seed(0)
random.seed(0)

true_w = torch.tensor([2, -3.4, 4.9, -1, 3.2, 0.35])
# The networks in this notebook (forward / forward_sparse) have no bias term,
# so the generated targets are kept bias-free to match.
true_b = 0.0
features, labels = synthetic_data(true_w, true_b, 1000)

In [3]:
def data_iter(batch_size, features, labels):
    """Yield ``(features, labels)`` minibatches in a shuffled order.

    A list of row indices is shuffled with ``random.shuffle`` and then cut into
    consecutive chunks of ``batch_size``. The final chunk may be shorter.
    """
    total = len(features)
    order = list(range(total))
    random.shuffle(order)
    for start in range(0, total, batch_size):
        picked = torch.tensor(order[start:start + batch_size])
        yield features[picked], labels[picked]

In [4]:
def forward(features, W_hidden, w_out):
    """Dense two-layer linear map: ``(features @ W_hidden) @ w_out``."""
    hidden = features @ W_hidden
    return hidden @ w_out

In [5]:
def forward_sparse(features, W_hidden, w_out):
    """Sparse counterpart of ``forward``: returns ``(features @ W_hidden @ w_out).T``.

    ``torch.sparse.mm`` expects its first operand to be a sparse COO matrix and the
    second to be strided (dense). The product is therefore evaluated in transposed
    form, ``(X W w)^T = w^T W^T X^T``, so that each sparse factor comes first.

    Args:
        features: dense ``(n, d)`` inputs.
        W_hidden: sparse hidden weights, ``(d, h)``.
        w_out: sparse output weights, e.g. ``(h, 1)``.

    Returns:
        Dense tensor of shape ``(1, n)`` when ``w_out`` is ``(h, 1)``. This is the
        transpose of ``forward``'s ``(n, 1)`` output.
    """
    # W_hidden must be transposed as well. Without it the result is X W^T w, which
    # matches ``forward`` only when W_hidden happens to be symmetric.
    hidden = sparse.mm(W_hidden.t(), features.t())  # (h, n)
    return sparse.mm(w_out.t(), hidden)

In [6]:
def squared_loss(y_hat, y):  #@save
    """Elementwise half squared error ``(y_hat - y)^2 / 2``.

    ``y`` is reshaped to ``y_hat``'s shape first, so column and row layouts with
    the same number of elements can be compared directly.
    """
    diff = y_hat - y.reshape_as(y_hat)
    return diff.pow(2) / 2

In [7]:
def top_kast_forward(w, kast = 0.5):
    """Zero the smallest entries of ``w`` and return the result as a sparse tensor.

    Entries strictly below the ``kast``-quantile of ``w`` (computed over the
    flattened values) are set to zero. The default ``kast=0.5`` therefore keeps
    roughly the largest 50% of coefficients. This is a simplified Top-KAST-style
    forward mask, not exactly the paper's rule: it ranks signed values, not
    magnitudes.

    Args:
        w: dense weight tensor. It must be convertible with ``.detach().numpy()``.
        kast: quantile in ``[0, 1]`` used as the pruning threshold, i.e. roughly
            the fraction of entries zeroed (up to ties).

    Returns:
        A new sparse COO tensor with ``requires_grad=True``. ``w`` itself is left
        unmodified.
    """
    threshold = np.quantile(w.detach().numpy().reshape(-1), kast)
    # Mask out-of-place so the caller's tensor is not mutated. In-place assignment
    # would also raise on a leaf tensor that requires grad.
    pruned = w.masked_fill(w < threshold, 0)
    return pruned.to_sparse().requires_grad_(True)

In [8]:
def compute_mask(w, kast = 0.5):
    """Return a boolean mask marking entries of ``w`` below its ``kast``-quantile.

    Args:
        w: dense weight tensor. It must be convertible with ``.detach().numpy()``.
        kast: quantile in ``[0, 1]`` used as the threshold. With the default 0.5,
            roughly the smaller half of the (signed) values is marked.

    Returns:
        Boolean tensor shaped like ``w``, ``True`` where the entry is strictly
        below the threshold (i.e. the entries a Top-KAST-style mask would prune).
    """
    threshold = np.quantile(w.detach().numpy().reshape(-1), kast)
    return w < threshold

In [29]:
def sgd(params, lr, batch_size):  #@save
    """Apply one in-place minibatch SGD step and then clear gradients.

    For every tensor ``p`` in ``params`` this performs
    ``p -= lr * p.grad / batch_size`` outside autograd tracking, then zeroes
    ``p.grad``.
    """
    with torch.no_grad():
        for p in params:
            step = lr * p.grad / batch_size
            p.sub_(step)
            p.grad.zero_()

In [34]:
lr = 0.03
num_epochs = 10
net = forward_sparse
loss = squared_loss
batch_size = 10

# Create the parameters directly as sparse *leaf* tensors that require grad.
# Calling .to_sparse() on a grad-requiring dense tensor yields a non-leaf instead:
# it needs retain_grad() to expose .grad, and every backward pass also
# accumulates an unused gradient into the hidden dense tensor.
W_hidden = torch.normal(0, 0.01, size=(6, 6)).to_sparse().requires_grad_(True)
w_out = torch.normal(0, 0.01, size=(6, 1)).to_sparse().requires_grad_(True)

for epoch in range(num_epochs):
    for X, y in data_iter(batch_size, features, labels):
        # TODO(review): no Top-KAST mask (e.g. compute_mask on W_hidden / w_out)
        # is applied yet; the forward pass uses all stored weights.
        y_hat = net(X, W_hidden, w_out)
        l = loss(y_hat, y)  # Minibatch loss in `X` and `y`
        l.sum().backward()
        sgd([W_hidden, w_out], lr, batch_size)  # Update parameters using their gradient
    with torch.no_grad():
        train_l = loss(net(features, W_hidden, w_out), labels)
        print(f'epoch {epoch + 1}, loss {float(train_l.mean()):f}')

epoch 1, loss 0.000059
epoch 2, loss 0.000054
epoch 3, loss 0.000056
epoch 4, loss 0.000053
epoch 5, loss 0.000055
epoch 6, loss 0.000054
epoch 7, loss 0.000058
epoch 8, loss 0.000060
epoch 9, loss 0.000055
epoch 10, loss 0.000054
