In [226]:
import torch
from torchvision import datasets
from torchvision.transforms import ToTensor

### Import MNIST dataset

In [227]:
training_data = datasets.MNIST(
    root='data',
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.MNIST(
    root='data',
    train=False,
    download=True,
    transform=ToTensor()
)

### Basic information about the dataset

In [228]:
training_data.class_to_idx

{'0 - zero': 0,
 '1 - one': 1,
 '2 - two': 2,
 '3 - three': 3,
 '4 - four': 4,
 '5 - five': 5,
 '6 - six': 6,
 '7 - seven': 7,
 '8 - eight': 8,
 '9 - nine': 9}

In [229]:
len(training_data)

60000

In [230]:
len(test_data)

10000

In [231]:
img, label = training_data[0]
img.size()

torch.Size([1, 28, 28])

### Create DataLoader

In [232]:
from torch.utils.data import DataLoader

In [233]:
batch_size = 64

train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

In [234]:
# Check mps maybe if working in MacOS
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using {device} device')

Using cpu device


### Create Model with Binary Outputs

In [235]:
from torch import nn
import torch.nn.functional as F
import torch.optim as optim

**Define sign activation function**

In [236]:
class STEFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        return (input > 0).float()
    
    @staticmethod
    def backward(ctx, grad_output):
        return F.hardtanh(grad_output)

class StraightThroughEstimator(nn.Module):
    def __init__(self):
        super(StraightThroughEstimator, self).__init__()
        
    def forward(self, x):
        x = STEFunction.apply(x)
        return x



In [237]:
class FPNeuralNetwork(nn.Module):
    def __init__(self):
        super(FPNeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.stack = nn.Sequential(
            nn.Linear(28*28, 100),
            nn.BatchNorm1d(100),
            nn.Dropout(p=0.1),
            StraightThroughEstimator(),
            nn.Linear(100, 100),
            nn.BatchNorm1d(100),
            nn.Dropout(p=0.1),
            StraightThroughEstimator(),
            nn.Linear(100, 100),
            nn.BatchNorm1d(100),
            nn.Dropout(p=0.1),
            StraightThroughEstimator(),
            nn.Linear(100, 100),
            nn.BatchNorm1d(100),
            nn.Dropout(p=0.1),
            StraightThroughEstimator(),
            nn.Linear(100, 10),
            nn.BatchNorm1d(10),
        )
    
    def forward(self, x):
        x = self.flatten(x)
        logits = self.stack(x)
        return F.log_softmax(logits)

net = FPNeuralNetwork().to(device)
print(net)

FPNeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (stack): Sequential(
    (0): Linear(in_features=784, out_features=100, bias=True)
    (1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): Dropout(p=0.1, inplace=False)
    (3): StraightThroughEstimator()
    (4): Linear(in_features=100, out_features=100, bias=True)
    (5): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): Dropout(p=0.1, inplace=False)
    (7): StraightThroughEstimator()
    (8): Linear(in_features=100, out_features=100, bias=True)
    (9): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): Dropout(p=0.1, inplace=False)
    (11): StraightThroughEstimator()
    (12): Linear(in_features=100, out_features=100, bias=True)
    (13): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (14): Dropout(p=0.1, inplace=False)
    (15): StraightThroughEstimator()


### Model Training

In [238]:
opt = optim.Adamax(net.parameters(), lr=3e-3, weight_decay=1e-4)

In [239]:
def train(dataloader, model, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)
        
        # Compute prediction error
        pred = model(X)
        loss = F.nll_loss(pred, y)
        
        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if batch % 100 == 0:
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

In [240]:
def test(dataloader, model):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += F.nll_loss(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

In [241]:
epochs = 10
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, net, opt)
    test(test_dataloader, net)
print("Done!")

Epoch 1
-------------------------------
loss: 2.709189  [   64/60000]


  return F.log_softmax(logits)


loss: 0.871715  [ 6464/60000]
loss: 0.754001  [12864/60000]
loss: 0.788550  [19264/60000]
loss: 0.734998  [25664/60000]
loss: 0.671158  [32064/60000]
loss: 0.567648  [38464/60000]
loss: 0.650778  [44864/60000]
loss: 0.534939  [51264/60000]
loss: 0.678903  [57664/60000]
Test Error: 
 Accuracy: 88.3%, Avg loss: 0.400633 

Epoch 2
-------------------------------
loss: 0.530654  [   64/60000]
loss: 0.481726  [ 6464/60000]
loss: 0.484848  [12864/60000]
loss: 0.581437  [19264/60000]
loss: 0.399799  [25664/60000]
loss: 0.590558  [32064/60000]
loss: 0.337829  [38464/60000]
loss: 0.653213  [44864/60000]
loss: 0.427863  [51264/60000]
loss: 0.707894  [57664/60000]
Test Error: 
 Accuracy: 89.7%, Avg loss: 0.352422 

Epoch 3
-------------------------------
loss: 0.294213  [   64/60000]
loss: 0.575122  [ 6464/60000]
loss: 0.447706  [12864/60000]
loss: 0.526050  [19264/60000]
loss: 0.477657  [25664/60000]
loss: 0.545274  [32064/60000]
loss: 0.433417  [38464/60000]
loss: 0.596418  [44864/60000]
loss: 

In [258]:
for name, param in net.named_parameters():
    print(name)
    if name.startswith('stack.1.'):
        print(param.size())


stack.0.weight
stack.0.bias
stack.1.weight
torch.Size([100])
stack.1.bias
torch.Size([100])
stack.4.weight
stack.4.bias
stack.5.weight
stack.5.bias
stack.8.weight
stack.8.bias
stack.9.weight
stack.9.bias
stack.12.weight
stack.12.bias
stack.13.weight
stack.13.bias
stack.16.weight
stack.16.bias
stack.17.weight
stack.17.bias


In [253]:
for idx, m in enumerate(net.named_modules()):
    print(f'{idx} -> {m}')
    break

0 -> ('', FPNeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (stack): Sequential(
    (0): Linear(in_features=784, out_features=100, bias=True)
    (1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): Dropout(p=0.1, inplace=False)
    (3): StraightThroughEstimator()
    (4): Linear(in_features=100, out_features=100, bias=True)
    (5): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): Dropout(p=0.1, inplace=False)
    (7): StraightThroughEstimator()
    (8): Linear(in_features=100, out_features=100, bias=True)
    (9): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): Dropout(p=0.1, inplace=False)
    (11): StraightThroughEstimator()
    (12): Linear(in_features=100, out_features=100, bias=True)
    (13): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (14): Dropout(p=0.1, inplace=False)
    (15): StraightThroughEs