In [1]:
from BinaryLayers import *

In [2]:
# For Binary Layers
# Both values are forwarded unchanged as the `H=` / `deterministic=` keyword
# arguments of every BinaryConv2D / BinaryDense layer built in the model cell.
H = 1
deterministic = True

# For Training
batch_size = 50   # samples per DataLoader batch (train and test)
num_epochs = 500  # number of passes over the training loader

# For batchnorm
epsilon = 1e-4  # passed as `eps` to every BatchNorm layer
alpha = .1      # passed as `momentum` to every BatchNorm layer

# For the optimizer
lr = .003  # Adam learning rate

# Set GPU
# Fall back to the CPU when CUDA is not available so that the notebook still
# runs end-to-end (slowly) instead of failing at the first `.to(device)`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
class nn(Module):
    """VGG-style image classifier assembled from binary conv / dense layers.

    Layout: three stages of two 3x3 binary convolutions (128, 256 and 512
    channels), where the second convolution of each stage is followed by a
    2x2 max-pool; then two 1024-unit binary dense layers and a 10-way binary
    dense output. Every layer except the last is followed by batch-norm and
    ReLU; the last layer is followed by a softmax over the class axis.

    BinaryConv2D / BinaryDense are not defined in this notebook (presumably
    they come from the `BinaryLayers` star import). The layer hyper-parameters
    `H`, `deterministic`, `epsilon` and `alpha` are read from the notebook's
    globals when the model is constructed.
    """

    def __init__(self):
        super().__init__()

        # Settings shared by every binary layer / every batch-norm layer.
        binary_kw = dict(H=H, deterministic=deterministic)
        bn_kw = dict(eps=epsilon, momentum=alpha)

        # Stage 1: 3 -> 128 -> 128 channels, pooled after the second conv.
        self.c3_128_1 = BinaryConv2D(3, 128, 3, padding=1, **binary_kw)
        self.bn_1 = torch.nn.BatchNorm2d(128, **bn_kw)
        self.c3_128_2 = BinaryConv2D(128, 128, 3, padding=1, **binary_kw)
        self.mp2_2 = torch.nn.MaxPool2d(2)
        self.bn_2 = torch.nn.BatchNorm2d(128, **bn_kw)

        # Stage 2: 128 -> 256 -> 256 channels, pooled after the second conv.
        self.c3_256_3 = BinaryConv2D(128, 256, 3, padding=1, **binary_kw)
        self.bn_3 = torch.nn.BatchNorm2d(256, **bn_kw)
        self.c3_256_4 = BinaryConv2D(256, 256, 3, padding=1, **binary_kw)
        self.mp2_4 = torch.nn.MaxPool2d(2)
        self.bn_4 = torch.nn.BatchNorm2d(256, **bn_kw)

        # Stage 3: 256 -> 512 -> 512 channels, pooled after the second conv.
        self.c3_512_5 = BinaryConv2D(256, 512, 3, padding=1, **binary_kw)
        self.bn_5 = torch.nn.BatchNorm2d(512, **bn_kw)
        self.c3_512_6 = BinaryConv2D(512, 512, 3, padding=1, **binary_kw)
        self.mp2_6 = torch.nn.MaxPool2d(2)
        self.bn_6 = torch.nn.BatchNorm2d(512, **bn_kw)

        # Classifier head. 8192 input features = 512 channels x 4 x 4
        # positions, i.e. this assumes 32x32 inputs (three 2x2 poolings).
        self.d_1024_7 = BinaryDense(512 * 4 * 4, 1024, **binary_kw)
        self.bn_7 = torch.nn.BatchNorm1d(1024, **bn_kw)
        self.d_1024_8 = BinaryDense(1024, 1024, **binary_kw)
        self.bn_8 = torch.nn.BatchNorm1d(1024, **bn_kw)
        self.d_10_9 = BinaryDense(1024, 10, **binary_kw)

    def forward(self, input):
        """Return softmax class scores for a batch of images."""
        # Stage 1
        feat = torch.relu(self.bn_1(self.c3_128_1(input)))
        feat = torch.relu(self.bn_2(self.mp2_2(self.c3_128_2(feat))))

        # Stage 2
        feat = torch.relu(self.bn_3(self.c3_256_3(feat)))
        feat = torch.relu(self.bn_4(self.mp2_4(self.c3_256_4(feat))))

        # Stage 3
        feat = torch.relu(self.bn_5(self.c3_512_5(feat)))
        feat = torch.relu(self.bn_6(self.mp2_6(self.c3_512_6(feat))))

        # Collapse every non-batch dimension before the dense head.
        feat = torch.flatten(feat, 1)

        feat = torch.relu(self.bn_7(self.d_1024_7(feat)))
        feat = torch.relu(self.bn_8(self.d_1024_8(feat)))

        return torch.softmax(self.d_10_9(feat), -1)

In [4]:
# Instantiate the network and move its parameters and buffers to the configured device.
model = nn().to(device)

In [5]:
# Bare expression: the notebook renders the module's repr (its layer listing) as the cell output.
model

nn(
  (c3_128_1): BinaryConv2D(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn_1): BatchNorm2d(128, eps=0.0001, momentum=0.1, affine=True, track_running_stats=True)
  (c3_128_2): BinaryConv2D(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (mp2_2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (bn_2): BatchNorm2d(128, eps=0.0001, momentum=0.1, affine=True, track_running_stats=True)
  (c3_256_3): BinaryConv2D(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn_3): BatchNorm2d(256, eps=0.0001, momentum=0.1, affine=True, track_running_stats=True)
  (c3_256_4): BinaryConv2D(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (mp2_4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (bn_4): BatchNorm2d(256, eps=0.0001, momentum=0.1, affine=True, track_running_stats=True)
  (c3_512_5): BinaryConv2D(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn_5): BatchNor

In [6]:
# Image pipeline shared by both splits. With mean 0 and std 1 the Normalize
# step leaves the tensor produced by ToTensor numerically unchanged; the
# statistics are written as real one-element tuples (`(0)` is just the int 0).
t = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0,), std=(1,)),
    ]
)


def _one_hot_label(label, num_classes=10):
    """Turn an integer class index into a (1, num_classes) one-hot LongTensor.

    A named module-level function (rather than a lambda) so the dataset stays
    picklable if the loaders are ever given worker processes.
    """
    return torch.nn.functional.one_hot(torch.LongTensor([label]), num_classes)


def _cifar10_loader(train, shuffle):
    """Build a CIFAR-10 DataLoader for the train (`train=True`) or test split.

    Both splits share the image transform, the one-hot target transform, the
    global `batch_size` and `drop_last=True`; only the split and the shuffling
    differ.
    """
    dataset = torchvision.datasets.CIFAR10(
        "/data/cifar",
        download=True,
        train=train,
        transform=t,
        target_transform=_one_hot_label,
    )
    return DataLoader(
        dataset,
        batch_size=batch_size,
        drop_last=True,
        shuffle=shuffle,
    )


dl_train = _cifar10_loader(train=True, shuffle=True)
# Evaluation does not benefit from shuffling; a fixed order keeps per-batch
# validation numbers comparable between epochs and runs.
dl_test = _cifar10_loader(train=False, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [7]:
# SquareHingeLoss is not defined in this notebook; presumably it is provided by
# the `from BinaryLayers import *` star import. It is passed to train() as a
# plain callable taking (output, target).
# NOTE(review): the recorded run below shows the training loss sitting at
# ~0.247 for the whole first epoch. The network ends in a softmax and the
# targets are one-hot 0/1 vectors -- confirm that this is the output range and
# target encoding SquareHingeLoss expects.
loss = SquareHingeLoss
# loss = torch.nn.MSELoss()
# Adam over every model parameter, using the learning rate from the configuration cell.
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

In [8]:
def train(model, num_epochs, dl_train, dl_valid, optimizer, lossfunction, device=None):
    """Train `model` for `num_epochs` epochs, validating after every epoch.

    Progress (running mean training loss every 10 steps, then per-epoch
    training loss, validation loss and validation accuracy) is printed.

    Args:
        model: torch module to optimise. It is put in train mode for the
            training pass and in eval mode for the validation pass, so it is
            left in eval mode when the function returns.
        num_epochs: number of passes over `dl_train`.
        dl_train: sized iterable of `(input, target)` training batches.
        dl_valid: sized iterable of `(input, target)` validation batches.
        optimizer: optimizer whose `zero_grad()` / `step()` drive the updates.
        lossfunction: callable `(output, float_target) -> scalar tensor`.
        device: device each batch is moved to. Defaults to the device of the
            model's first parameter (CPU if the model has no parameters).

    Targets are reshaped to `(-1, output.shape[-1])`, i.e. one row per sample
    with as many columns as the model has outputs, and cast to float before
    being handed to `lossfunction`. Accuracy compares the argmax of the target
    row with the argmax of the output row.

    Returns:
        `(losses, val_losses)`: two lists of length `num_epochs`. Each entry
        is the *sum* of the per-batch losses of that epoch (divide by the
        loader length to obtain the mean that is printed).
    """
    if device is None:
        # Follow the model instead of relying on a notebook-level global.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    losses = [0] * num_epochs
    val_losses = [0] * num_epochs
    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}")

        # Training pass
        model.train()
        for step, (batch, target) in enumerate(dl_train, start=1):
            optimizer.zero_grad()
            output = model(batch.to(device))
            # One row per sample, one column per model output.
            target = target.to(device).reshape(-1, output.shape[-1])

            loss = lossfunction(output, target.float())
            losses[epoch] += loss.item()
            loss.backward()
            optimizer.step()

            if step % 10 == 0:
                print(step, "{:.04}".format(losses[epoch] / step), end="\t")

        # Validation pass
        model.eval()
        correct = 0
        seen = 0
        with torch.no_grad():
            for batch, target in dl_valid:
                output = model(batch.to(device))
                target = target.to(device).reshape(-1, output.shape[-1])
                val_losses[epoch] += lossfunction(output, target.float()).item()
                # Count hits on-tensor; dividing by the number of samples seen
                # stays exact even when batches have different sizes.
                hits = torch.argmax(target, -1) == torch.argmax(output, -1)
                correct += hits.sum().item()
                seen += target.shape[0]

        print("")
        # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
        print("Epoch training loss", losses[epoch] / max(len(dl_train), 1))
        print("Epoch valid loss", val_losses[epoch] / max(len(dl_valid), 1))
        print("Validation Accuracy:", correct / seen if seen else 0.0)
    return losses, val_losses

In [9]:
# Run the full training schedule, using the CIFAR-10 test loader as the validation set.
# The returned (losses, val_losses) tuple is not assigned to a name, so it is only
# available as this cell's output.
train(model, num_epochs, dl_train, dl_test, optimizer, loss)

Epoch 1
10 0.247	20 0.2471	30 0.247	40 0.2471	50 0.2469	60 0.2469	70 0.247	80 0.2469	90 0.2469	100 0.2469	110 0.2469	120 0.2469	130 0.2469	140 0.2469	150 0.2469	160 0.2469	170 0.2469	180 0.2469	190 0.247	200 0.247	210 0.247	220 0.247	230 0.247	240 0.247	250 0.2471	260 0.2471	270 0.2471	280 0.2471	290 0.2471	300 0.2471	310 0.2472	320 0.2471	330 0.2472	340 0.2472	350 0.2471	360 0.2472	370 0.2472	380 0.2472	390 0.2472	400 0.2472	410 0.2472	420 0.2472	430 0.2472	440 0.2472	450 0.2472	460 0.2472	470 0.2472	480 0.2472	490 0.2472	

KeyboardInterrupt: 