In [1]:
# prompt: Implement a NN that adds two bits (a0 and b0) and outputs the sum (a0+b0 = c1c0)
# use pytorch

import torch
import torch.nn as nn

# Define the neural network model
class AdderNN(nn.Module):
  def __init__(self):
    super(AdderNN, self).__init__()
    self.fc1 = nn.Linear(2, 3)  # Input: 2 bits, Output: 3 neurons
    # Q. how many parameters does self.fc1 have? 6 + 3 = 9 parameters
    # Q. in other words, how many bytes are required to store "self.fc1 " => 9 * 4 bytes (fp32)
    # Random -- LLMs
    # GPT2 was around 100M ~ 1B => (fp32) 4 bytes * 10^9 (1G) = 4GB
    # GPT3 was around 100B parameters => 400GB
    # GPT4 unknown // Llama3 500B parameters => 2TB memory => 1TB (fp16) => (INT6, INT4, INT3, INT2) => 125GB (INT2)
    # (fp16) 16 bits / parameter = norm
    self.fc2 = nn.Linear(3, 2)  # Input: 3 neurons, Output: 2 bits (sum)
    # Q. how many params? 3 * 2 + 2 = 8 parameters
    self.relu = torch.nn.ReLU()

  def forward(self, x): # [MOST IMPORTANT] the NN is defined here
    h = self.relu(self.fc1(x)) # x is the input , h is the hidden activations
    y = (self.fc2(h) > 0).float() # h is the hidden activations, y is the final output
    return y

# Create the model
model = AdderNN()


In [2]:
model

AdderNN(
  (fc1): Linear(in_features=2, out_features=3, bias=True)
  (fc2): Linear(in_features=3, out_features=2, bias=True)
  (relu): ReLU()
)

We want to add a + b = c1c0

c1 = 1 ( a + b >= 2 ) = step(ReLU( a + b - 1.5 ))
c0 = 1 ( a + b = 1 ) = step(ReLU( a + b - 0.5 ) - 2 * ReLU( a + b - 1.5))

In [3]:
model.fc1.weight.data

tensor([[-0.5941,  0.6577],
        [-0.1492,  0.3492],
        [ 0.1747,  0.1888]])

In [4]:
model.fc1.weight.data = torch.Tensor([[1, 1], [1, 1], [1, 1]])
model.fc1.bias.data = torch.Tensor([-1.5, -0.5, -1.5])
model.fc2.weight.data = torch.Tensor([[1, 0, 0], [0, 1, -2]])
model.fc2.bias.data = torch.Tensor([0, 0])

In [5]:
x = torch.Tensor([[0,0], [0,1], [1, 0], [1,1]]) # batch size =1 , the only input is [1,0]

In [6]:
model(x) # all correct!

tensor([[0., 0.],
        [0., 1.],
        [0., 1.],
        [1., 1.]])