# Toy Models of Superposition

In [2]:
# import packages
import torch

# create model
class LinearModel():
    """Tied-weight linear autoencoder trained with hand-written gradient descent.

    An input of width ``input_size`` is projected to ``hidden_size`` dimensions
    with ``weights`` and mapped back with the transpose of the same matrix plus
    ``bias``, i.e. ``x @ W @ W.T + b``. No non-linearity is applied.

    Attributes:
        weights: ``(input_size, hidden_size)`` tensor, standard-normal init.
        bias: ``(1, input_size)`` tensor, zero init.
        parameters: ``[bias, weights]``, the tensors updated by ``train_step``.
    """

    def __init__(self, input_size, hidden_size):
        """Create the trainable tensors.

        Args:
            input_size: width of the input (and of the reconstruction).
            hidden_size: width of the intermediate projection.
        """
        # requires_grad=True turns these into leaf tensors tracked by autograd.
        # Without it the loss has no grad_fn and loss.backward() raises.
        self.weights = torch.randn(size=(input_size, hidden_size), requires_grad=True)
        self.bias = torch.zeros((1, input_size), requires_grad=True)
        self.parameters = [self.bias, self.weights]

    def forward(self, x):
        """Return the reconstruction ``x @ W @ W.T + b``.

        Args:
            x: tensor whose last dimension is ``input_size``.

        Returns:
            Tensor with last dimension ``input_size``. Because ``bias`` has
            shape ``(1, input_size)``, a 1-D input broadcasts to a
            ``(1, input_size)`` result.
        """
        hidden = x @ self.weights
        return hidden @ torch.transpose(self.weights, 0, 1) + self.bias

    def train_step(self, x, y, lr = 0.01):
        """Run one gradient-descent step on the mean squared error.

        Args:
            x: input batch passed to ``forward``.
            y: target tensor, broadcastable against ``forward(x)``.
            lr: step size of the update.

        Returns:
            The loss as a Python float, evaluated before the update is applied.
        """
        # forward pass
        out = self.forward(x)

        # compute loss
        loss = torch.mean((out - y)**2)

        # Backward pass
        loss.backward()

        # Update and reset gradients. The update itself must not be recorded by
        # autograd, hence no_grad(); zero_() is the in-place reset (gradients
        # would otherwise accumulate across steps).
        with torch.no_grad():
            for p in self.parameters:
                p -= lr * p.grad
                p.grad.zero_()

        return loss.item()



In [11]:
# Hyperparameters
SEED = 0             # fixed so Restart & Run All reproduces the same data and initial weights
S = 0.1              # probability that any single input element is zeroed out
input_size = 20
hidden_size = 5
sample_size = 1000   # number of rows generated for `data`
batch_size = 20      # NOTE(review): not used by any code in this cell

torch.manual_seed(SEED)

# Create uniform [0, 1) data with `sample_size` rows, then zero out a random
# fraction (about S) of the elements to mimic sparsity.
data = torch.rand((sample_size, input_size))
mask = torch.rand(data.shape) < S
data[mask] = 0

# Initialize model
model = LinearModel(input_size, hidden_size)








tensor([[ 2.0013e+00,  1.0201e+00, -1.4404e+00, -6.9425e+00, -5.1367e-01,
          3.6418e+00,  1.1280e+01, -7.4025e+00, -1.2526e+00,  5.1069e+00,
          9.3632e+00,  2.0530e+00, -3.8955e+00,  8.4232e-01,  5.0814e+00,
         -2.2984e+00,  9.6399e+00,  3.8838e+00,  3.5740e+00, -3.2782e+00],
        [-1.8588e+00,  6.9348e-01, -4.6115e-01,  2.1384e+00, -2.8335e+00,
          7.5670e-01,  5.1834e+00, -3.3768e+00,  2.9413e+00,  1.4905e+00,
          2.2879e+00,  1.0230e+00, -3.1204e+00, -5.4946e-01,  1.7177e+00,
          4.7721e-01,  5.3647e+00,  4.1893e+00, -7.3679e-02, -1.0983e+00],
        [ 3.1688e+00,  1.5213e+00, -6.1872e+00, -7.6775e+00, -6.2321e-02,
          1.4928e+00,  4.8097e+00,  2.0192e+00, -2.8970e+00,  2.3544e+00,
          9.3619e+00, -2.8311e+00, -6.5694e+00,  4.3184e-01,  1.1473e+00,
         -2.1794e+00,  6.5431e-01,  9.6448e+00,  5.0926e+00, -1.7585e+00],
        [ 3.0210e+00, -1.5590e+00, -5.7467e+00, -5.6498e+00,  1.4201e+00,
          2.2157e+00,  5.3241e+00, 

In [21]:
# Smoke test: push an all-ones vector through a freshly initialised model.
# NOTE(review): this rebinds `model`, replacing any instance created by an
# earlier cell, and repeats the sizes 20 and 5 as literals.
model = LinearModel(input_size=20, hidden_size=5)
v = torch.ones(20)
# v is 1-D; the (1, 20) bias broadcasts the displayed result to shape (1, 20).
model.forward(v)


tensor([[  6.7475,  -5.8833,   2.9287,  10.9396, -14.5848,  16.6833,  -0.1689,
           6.5663,   3.2789,  -1.1569, -10.3093,   8.9053,   3.2356,   9.7727,
          11.1480,  -9.3307,   4.5352,   1.9709,  20.1343,   5.1474]])