In [1]:
# Load IPython's autoreload extension and select mode 2, which reloads
# imported modules before each cell executes so edits to library code are
# picked up without restarting the kernel.
# Keep comments on their own lines: trailing text on a magic line is passed
# to the magic as arguments.
%load_ext autoreload
%autoreload 2

In [1]:
from researchlib.single_import import *

# 1. Binary Linear Model

In [2]:
# Binarized MLP: reshape each input to 784 features, run it through three
# BinarizeLinear -> BatchNorm1d -> Hardtanh stages (512, 256, 128 units),
# then a 10-way BinarizeLinear head followed by LogSoftmax, so the network
# emits log-probabilities.
mlp_widths = (784, 512, 256, 128)

mlp_layers = [Reshape((-1, mlp_widths[0]))]
for n_in, n_out in zip(mlp_widths[:-1], mlp_widths[1:]):
    # Modules are instantiated in the same order as in a hand-written list.
    mlp_layers += [
        layer.BinarizeLinear(n_in, n_out),
        nn.BatchNorm1d(n_out),
        nn.Hardtanh(),
    ]
mlp_layers += [
    layer.BinarizeLinear(mlp_widths[-1], 10),
    nn.LogSoftmax(-1),
]

model = builder(mlp_layers)

In [13]:
# Build the two MNIST loaders with a shared batch size of 128. The boolean
# passed to vision.MNIST presumably selects the training split (True) versus
# the held-out split (False) -- TODO confirm against researchlib. The training
# progress bars show 469 steps per epoch, which is consistent with 60,000
# samples at 128 per batch.
train_loader, test_loader = [
    FromVisionDataset(vision.MNIST(use_train_split), batch_size=128)
    for use_train_split in (True, False)
]

In [14]:
# Bundle the current `model` with the train/test loaders for training.
# 'adam' and 'nll' are string keys resolved by Runner -- presumably the Adam
# optimizer and a negative log-likelihood loss (TODO confirm); an NLL loss
# would match the LogSoftmax output layer of the model.
runner = Runner(model, train_loader, test_loader, 'adam', 'nll')

In [15]:
# (Re)initialize the model's weights before training. Per the log printed
# below, xavier_normal is reported for every BinarizeLinear layer and for the
# enclosing Sequential container.
runner.init_model()

Init xavier_normal: BinarizeLinear(in_features=784, out_features=512, bias=True)
Init xavier_normal: BinarizeLinear(in_features=512, out_features=256, bias=True)
Init xavier_normal: BinarizeLinear(in_features=256, out_features=128, bias=True)
Init xavier_normal: BinarizeLinear(in_features=128, out_features=10, bias=True)
Init xavier_normal: Sequential(
  (0): Reshape()
  (1): BinarizeLinear(in_features=784, out_features=512, bias=True)
  (2): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (3): Hardtanh(min_val=-1.0, max_val=1.0)
  (4): BinarizeLinear(in_features=512, out_features=256, bias=True)
  (5): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (6): Hardtanh(min_val=-1.0, max_val=1.0)
  (7): BinarizeLinear(in_features=256, out_features=128, bias=True)
  (8): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (9): Hardtanh(min_val=-1.0, max_val=1.0)
  (10): BinarizeLinear(in_featu

In [17]:
# Train the linear model. The first argument is the epoch count (the outer
# progress bar and the results table below both show 4 epochs); 1e-2 is
# presumably the learning rate -- TODO confirm against Runner.fit.
# NOTE(review): Binarized() comes from the star import; presumably it handles
# the binary-weight bookkeeping around each optimizer step -- verify.
runner.fit(4, 1e-2, callbacks=[Binarized()])

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

  Epochs    train_loss    train_acc      val_loss      val_acc    
    1         0.8334        0.9308        1.1069        0.9215    


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

    2         0.7348        0.9399        0.4221        0.9645    


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

    3         0.6158        0.9481        0.9481        0.9336    


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

    4         0.6126        0.9510        0.4871        0.9631    



# 2. Binary Convolution Model

In [5]:
# Binarized CNN: two BinarizeConv2d(5x5) -> BatchNorm2d -> Hardtanh ->
# MaxPool2d(2) stages (1 -> 64 -> 128 channels), then flatten and classify
# into 10 classes with LogSoftmax. Note this rebinds `model`.
conv_channels = (1, 64, 128)

cnn_layers = []
for ch_in, ch_out in zip(conv_channels[:-1], conv_channels[1:]):
    cnn_layers += [
        layer.BinarizeConv2d(ch_in, ch_out, 5),
        nn.BatchNorm2d(ch_out),
        nn.Hardtanh(),
        nn.MaxPool2d(2),
    ]
# 2048 = 128 channels * 4 * 4, assuming 28x28 inputs and unpadded stride-1
# convolutions: 28 -> 24 (conv) -> 12 (pool) -> 8 (conv) -> 4 (pool).
cnn_layers += [
    Flatten(),
    layer.BinarizeLinear(2048, 10),
    nn.LogSoftmax(-1),
]

model = builder(cnn_layers)

In [20]:
# Rebind `runner` to a fresh Runner around the current `model`, reusing the
# MNIST loaders created earlier in the notebook.
# NOTE(review): this relies on `model` having been rebound to the
# convolutional network by the preceding cell; the recorded execution counts
# are not sequential, so re-run the notebook top-to-bottom to be sure the
# intended model is the one being wrapped.
runner = Runner(model, train_loader, test_loader, 'adam', 'nll')

In [21]:
# (Re)initialize the convolutional model's weights before training. Per the
# log printed below, xavier_normal is reported for the BinarizeConv2d and
# BinarizeLinear layers and for the enclosing Sequential container.
runner.init_model()

Init xavier_normal: BinarizeConv2d(1, 64, kernel_size=(5, 5), stride=(1, 1))
Init xavier_normal: BinarizeConv2d(64, 128, kernel_size=(5, 5), stride=(1, 1))
Init xavier_normal: BinarizeLinear(in_features=2048, out_features=10, bias=True)
Init xavier_normal: Sequential(
  (0): BinarizeConv2d(1, 64, kernel_size=(5, 5), stride=(1, 1))
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): Hardtanh(min_val=-1.0, max_val=1.0)
  (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (4): BinarizeConv2d(64, 128, kernel_size=(5, 5), stride=(1, 1))
  (5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (6): Hardtanh(min_val=-1.0, max_val=1.0)
  (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (8): Flatten()
  (9): BinarizeLinear(in_features=2048, out_features=10, bias=True)
  (10): LogSoftmax()
)
Init xavier_normal: Sequential(
  (0): BinarizeConv2d(1, 64, kernel_

In [22]:
# Train the convolutional model with the same settings as the linear one:
# 4 epochs (see the progress bar and results table below); 1e-2 is presumably
# the learning rate -- TODO confirm against Runner.fit.
# NOTE(review): Binarized() comes from the star import; presumably it handles
# the binary-weight bookkeeping around each optimizer step -- verify.
runner.fit(4, 1e-2, callbacks=[Binarized()])

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

  Epochs    train_loss    train_acc      val_loss      val_acc    
    1*        5.5719        0.8552        3.0029        0.9415    


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

    2         2.5013        0.9564        0.8964        0.9798    


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

    3         2.1064        0.9609        2.3800        0.9567    


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

    4         1.9188        0.9667        0.7076        0.9823    
