In [1]:
import torch.sparse as sparse
import numpy.random as random
import numpy as np
from torchviz import make_dot
import torch
from d2l import torch as d2l

In [2]:
def synthetic_data(w, b, num_examples):  #@save
    """Generate y = Xw + b + noise."""
    X = torch.normal(0, 1, (num_examples, len(w)))
    y = torch.matmul(X, w)
    y += torch.normal(0, 0.01, y.shape)
    return X, y.reshape((-1, 1))

true_w = torch.tensor([2, -3.4, 4.9, -1, 3.2, 0.35])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)

In [3]:
def data_iter(batch_size, features, labels):
    num_examples = len(features)
    indices = list(range(num_examples))
    # The examples are read at random, in no particular order
    random.shuffle(indices)
    for i in range(0, num_examples, batch_size):
        batch_indices = torch.tensor(indices[i:min(i +
                                                   batch_size, num_examples)])
        yield features[batch_indices], labels[batch_indices]

In [4]:
def forward(features, W_hidden, w_out):
    return torch.matmul(torch.matmul(features, W_hidden), w_out)

In [5]:
def forward_sparse(features, W_hidden, w_out):
    # import pdb; pdb.set_trace()
    hidden = sparse.mm(W_hidden, features.t())
    return sparse.mm(w_out.t(), hidden)
    # sparse.mm is very specific which matrices it gets. The first needs to sparse, the second needs to be strided/dense. So one needs to do a bunch of .t().

In [6]:
def squared_loss(y_hat, y):  #@save
    """Squared loss."""
    return (y_hat - y.reshape(y_hat.shape))**2 / 2

In [7]:
def top_kast_forward(w, kast = 0.5):
    """Selects 50% largest coefficients, not exactly like in paper, but for testing it should be fine."""
    # import pdb; pdb.set_trace()
    threshold = np.quantile(w.detach().numpy().reshape(-1), 0.5)
    mask = w < threshold
    w[mask] = 0
    return w.to_sparse().requires_grad_(True)

In [8]:
def compute_mask(w, kast = 0.5):
    # import pdb; pdb.set_trace()
    threshold = np.quantile(w.detach().numpy().reshape(-1), 0.5)
    mask = w < threshold
    return mask

In [81]:
def sgd(params, lr, batch_size):  #@save
    """Minibatch stochastic gradient descent."""
    with torch.no_grad():
        for param in params:
            import pdb; pdb.set_trace()
            param = - lr * param.grad_fn(param) / batch_size + param
            # param.grad_fn(param).zero_()

In [83]:
lr = 0.3
num_epochs = 5
net = forward_sparse
loss = squared_loss
batch_size = 10

W_hidden = torch.normal(0, 0.01, size=(6, 6), requires_grad = True).to_sparse()
w_out = torch.normal(0, 0.01, size=(6, 1), requires_grad = True).to_sparse()

for epoch in range(num_epochs):
    for X, y in data_iter(batch_size, features, labels):
        # Make sparse forward pass
        # mask_hidden = compute_mask(W_hidden)
        # mask_out = compute_mask(w_out)
        # import pdb; pdb.set_trace()
        y_hat = net(X, W_hidden, w_out)
        l = loss(y_hat, y)  # Minibatch loss in `X` and `y`
        l.sum().backward()
        sgd([W_hidden, w_out], lr, batch_size)  # Update parameters using their gradient
    with torch.no_grad():
        train_l = loss(net(features, W_hidden, w_out), labels)
        print(f'epoch {epoch + 1}, loss {float(train_l.mean()):f}')

> <ipython-input-81-1d4ccb0012e2>(6)sgd()
-> param = - lr * param.grad_fn(param) / batch_size + param
tensor([[-0.0142, -0.0018, -0.0078,  0.0069,  0.0094, -0.0056],
        [ 0.0133,  0.0110, -0.0096, -0.0041,  0.0116,  0.0175],
        [ 0.0051,  0.0221, -0.0058, -0.0092,  0.0137,  0.0049],
        [-0.0003, -0.0025,  0.0013,  0.0089,  0.0004, -0.0026],
        [-0.0158,  0.0024, -0.0023,  0.0232, -0.0005,  0.0012],
        [-0.0012,  0.0208, -0.0080,  0.0032, -0.0088,  0.0045]])
tensor(indices=tensor([[0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3,
                        3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5],
                       [0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 0,
                        1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5]]),
       values=tensor([-0.0142, -0.0018, -0.0078,  0.0069,  0.0094, -0.0056,
                       0.0133,  0.0110, -0.0096, -0.0041,  0.0116,  0.0175,
                       0.0051,  0.0221, -0.00

BdbQuit: 

In [44]:
W_hidden = torch.normal(0, 0.01, size=(6, 6), requires_grad = True).to_sparse()
w_out = torch.normal(0, 0.01, size=(6, 1), requires_grad = True).to_sparse()

In [45]:
y_hat = forward_sparse(features[:10,:], W_hidden, w_out)
y_hat

tensor([[ 2.3957e-04, -3.5383e-05, -1.7071e-05,  7.1259e-05,  3.4187e-04,
          2.3327e-05, -4.3133e-04,  4.7800e-05,  3.6476e-05,  5.9954e-04]],
       grad_fn=<SparseAddmmBackward>)

In [46]:
l = squared_loss(y_hat, labels[:10])
l

tensor([[ 0.2055,  0.5528, 32.2029, 74.8982, 19.1788,  6.0041,  2.2378, 39.9689,
         49.9687, 51.1675]], grad_fn=<DivBackward0>)

In [47]:
l.sum().backward()

In [60]:
w_out.grad_fn(W_hidden)

tensor([[-1.9062e-03,  1.2920e-03, -1.1964e-02, -2.4589e-02, -5.5270e-03,
         -4.1436e-03],
        [-4.9148e-03,  1.6498e-03, -2.9532e-03, -8.2215e-03,  9.5527e-03,
          4.5420e-03],
        [ 1.5493e-03, -3.6528e-03, -1.1500e-02,  1.7868e-02,  7.0601e-03,
          7.4590e-03],
        [-1.1612e-02, -4.5305e-03,  4.8653e-03,  3.1775e-03, -3.3420e-04,
          2.4657e-03],
        [-3.9694e-03,  1.8187e-02, -3.5265e-05,  2.2798e-02,  1.3404e-02,
          8.6537e-03],
        [ 5.7849e-03,  2.3365e-03,  1.0150e-02, -1.8158e-02,  8.0625e-03,
         -6.3099e-03]], grad_fn=<ToDenseBackward>)

In [43]:
sgd([W_hidden, w_out], lr, batch_size)

  if param.grad is not None:
