In [0]:
import torch
import pdb
import torch.nn as nn
import math
from torch.autograd import Variable
from torch.autograd import Function
import time

import numpy as np


def Binarize(tensor,quant_mode='det'):
    if quant_mode=='det':
        return tensor.sign()
    else:
        return tensor.add_(1).div_(2).add_(torch.rand(tensor.size()).add(-0.5)).clamp_(0,1).round().mul_(2).add_(-1)




class HingeLoss(nn.Module):
    def __init__(self):
        super(HingeLoss,self).__init__()
        self.margin=1.0

    def hinge_loss(self,input,target):
            #import pdb; pdb.set_trace()
            output=self.margin-input.mul(target)
            output[output.le(0)]=0
            return output.mean()

    def forward(self, input, target):
        return self.hinge_loss(input,target)

class SqrtHingeLossFunction(Function):
    def __init__(self):
        super(SqrtHingeLossFunction,self).__init__()
        self.margin=1.0

    def forward(self, input, target):
        output=self.margin-input.mul(target)
        output[output.le(0)]=0
        self.save_for_backward(input, target)
        loss=output.mul(output).sum(0).sum(1).div(target.numel())
        return loss

    def backward(self,grad_output):
       input, target = self.saved_tensors
       output=self.margin-input.mul(target)
       output[output.le(0)]=0
       import pdb; pdb.set_trace()
       grad_output.resize_as_(input).copy_(target).mul_(-2).mul_(output)
       grad_output.mul_(output.ne(0).float())
       grad_output.div_(input.numel())
       return grad_output,grad_output

def Quantize(tensor,quant_mode='det',  params=None, numBits=8):
    tensor.clamp_(-2**(numBits-1),2**(numBits-1))
    if quant_mode=='det':
        tensor=tensor.mul(2**(numBits-1)).round().div(2**(numBits-1))
    else:
        tensor=tensor.mul(2**(numBits-1)).round().add(torch.rand(tensor.size()).add(-0.5)).div(2**(numBits-1))
        quant_fixed(tensor, params)
    return tensor

import torch.nn._functions as tnnf


class BinarizeLinear(nn.Linear):

    def __init__(self, *kargs, **kwargs):
        super(BinarizeLinear, self).__init__(*kargs, **kwargs)

    def forward(self, input):

        if input.size(1) != 784:
            input.data=Binarize(input.data)
        if not hasattr(self.weight,'org'):
            self.weight.org=self.weight.data.clone()
        self.weight.data=Binarize(self.weight.org)
        out = nn.functional.linear(input, self.weight)
        if not self.bias is None:
            self.bias.org=self.bias.data.clone()
            out += self.bias.view(1, -1).expand_as(out)

        return out

class BinarizeConv2d(nn.Conv2d):

    def __init__(self, *kargs, **kwargs):
        super(BinarizeConv2d, self).__init__(*kargs, **kwargs)


    def forward(self, input):
        if input.size(1) != 3:
            input.data = Binarize(input.data)
        if not hasattr(self.weight,'org'):
            self.weight.org=self.weight.data.clone()
        self.weight.data=Binarize(self.weight.org)

        out = nn.functional.conv2d(input, self.weight, None, self.stride,
                                   self.padding, self.dilation, self.groups)

        if not self.bias is None:
            self.bias.org=self.bias.data.clone()
            out += self.bias.view(1, -1, 1, 1).expand_as(out)

        return out
      



In [2]:
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from tqdm import tqdm
# from models.binarized_modules import  BinarizeLinear,BinarizeConv2d
# from models.binarized_modules import  Binarize,Ternarize,Ternarize2,Ternarize3,Ternarize4,HingeLoss
# Training settings
# parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
# parser.add_argument('--batch-size', type=int, default=64, metavar='N',
#                     help='input batch size for training (default: 256)')
# parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
#                     help='input batch size for testing (default: 1000)')
# parser.add_argument('--epochs', type=int, default=100, metavar='N',
#                     help='number of epochs to train (default: 10)')
# parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
#                     help='learning rate (default: 0.001)')
# parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
#                     help='SGD momentum (default: 0.5)')
# parser.add_argument('--no-cuda', action='store_true', default=False,
#                     help='disables CUDA training')
# parser.add_argument('--seed', type=int, default=1, metavar='S',
#                     help='random seed (default: 1)')
# parser.add_argument('--gpus', default=3,
#                     help='gpus used for training - e.g 0,1,3')
# parser.add_argument('--log-interval', type=int, default=10, metavar='N',
#                     help='how many batches to wait before logging training status')
# args = parser.parse_args()
# args.cuda = not args.no_cuda and torch.cuda.is_available()

torch.manual_seed(1)
# if args.cuda:
#     torch.cuda.manual_seed(args.seed)


# kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=128, shuffle=True)


# 32C3 - MP2 - 64C3 - Mp2 - 512FC - SM10

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        
#         self.fc1 = BinarizeLinear(784, 2048*self.infl_ratio)
       
        
#         self.infl_ratio=3
#         self.fc1 = BinarizeLinear(784, 2048*self.infl_ratio)
#         self.htanh1 = nn.Hardtanh()
#         self.bn1 = nn.BatchNorm1d(2048*self.infl_ratio)
#         self.fc2 = BinarizeLinear(2048*self.infl_ratio, 2048*self.infl_ratio)
#         self.htanh2 = nn.Hardtanh()
#         self.bn2 = nn.BatchNorm1d(2048*self.infl_ratio)
#         self.fc3 = BinarizeLinear(2048*self.infl_ratio, 2048*self.infl_ratio)
#         self.htanh3 = nn.Hardtanh()
#         self.bn3 = nn.BatchNorm1d(2048*self.infl_ratio)
#         self.fc4 = nn.Linear(2048*self.infl_ratio, 10)
#         self.logsoftmax = nn.LogSoftmax(dim=1)
#         self.drop=nn.Dropout(0.5)
        
        self.conv1 = BinarizeConv2d(1, 32, kernel_size=3)
        self.mp1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.bn1 = nn.BatchNorm2d(32)
        self.htanh1 = nn.Hardtanh()
        
        self.conv2 = BinarizeConv2d(32, 64, kernel_size=3)
        self.mp2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.bn2 = nn.BatchNorm2d(64)
        self.htanh2 = nn.Hardtanh()
        
        self.fc1 = nn.Linear(64*5*5, 512)
        self.bn3 = nn.BatchNorm1d(512)
        self.htanh3 = nn.Hardtanh()
        
        self.fc2 = nn.Linear(512, 10)
        self.sm = nn.Softmax(dim=1)
        
          

    def forward(self, x):
#         x = x.view(-1, 28*28)
        x = Binarize(x)
        x = self.conv1(x)
        x = self.mp1(x)
        x = self.bn1(x)
        x = self.htanh1(x)
        x = Binarize(x)
        
        x = self.conv2(x)
        x = self.mp2(x)
        x = self.bn2(x)
        x = self.htanh2(x)
        x = Binarize(x)
        
#         print(x.shape)
#         x = x.view(-1, 64*5*5)
        x = x.view(x.size(0), -1)
        
        
        x = self.fc1(x)
        x = self.bn3(x)
        x = self.htanh3(x)
        x = Binarize(x)
        
        x = self.fc2(x)
     

        return self.sm(x)

model = Net()
torch.cuda.device('cuda')
model.cuda()


criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)


def train(epoch):
    model.train()
    
    losses = []
    trainloader = tqdm(train_loader)
    
    for batch_idx, (data, target) in enumerate(trainloader):
 
        data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)

#         if epoch%40==0:
#             optimizer.param_groups[0]['lr']=optimizer.param_groups[0]['lr']*0.1

#         optimizer.zero_grad()
        
        loss.backward()
    
        for p in list(model.parameters()):
            if hasattr(p,'org'):
                p.data.copy_(p.org)
        optimizer.step()
        
        for p in list(model.parameters()):
            if hasattr(p,'org'):
                p.org.copy_(p.data.clamp_(-1,1))
    
        losses.append(loss.item())
        trainloader.set_postfix(loss=np.mean(losses), epoch=epoch)
#         if batch_idx % 10000 == 0:
#             print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
#                 epoch, batch_idx * len(data), len(train_loader.dataset),
#                 100. * batch_idx / len(train_loader), loss.item()))

def test():
    model.eval()
    test_loss = 0
    correct = 0
    testloader = tqdm(test_loader)
    for data, target in testloader:
        data, target = data.cuda(), target.cuda()
        with torch.no_grad():
          data = Variable(data)
        target = Variable(target)
        output = model(data)
        test_loss += criterion(output, target).item() # sum up batch loss
        pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability
        correct += pred.eq(target.data.view_as(pred)).cpu().sum()
        
        

        testloader.set_postfix(loss=test_loss / len(test_loader.dataset),acc=str((100. *correct / len(test_loader.dataset)).numpy())+'%')
    
    test_loss /= len(test_loader.dataset)
    
    
#     print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
#         test_loss, correct, len(test_loader.dataset),
#         100. * correct / len(test_loader.dataset)))


for epoch in range(5):
    train(epoch)
    test()

100%|██████████| 469/469 [00:13<00:00, 34.80it/s, epoch=0, loss=1.6]
100%|██████████| 79/79 [00:01<00:00, 47.50it/s, acc=89%, loss=0.0124]
100%|██████████| 469/469 [00:11<00:00, 39.50it/s, epoch=1, loss=1.57]
100%|██████████| 79/79 [00:01<00:00, 47.74it/s, acc=89%, loss=0.0124]
100%|██████████| 469/469 [00:11<00:00, 39.42it/s, epoch=2, loss=1.57]
100%|██████████| 79/79 [00:01<00:00, 47.80it/s, acc=89%, loss=0.0124]
100%|██████████| 469/469 [00:11<00:00, 39.37it/s, epoch=3, loss=1.57]
100%|██████████| 79/79 [00:01<00:00, 47.96it/s, acc=89%, loss=0.0124]
100%|██████████| 469/469 [00:11<00:00, 39.30it/s, epoch=4, loss=1.57]
100%|██████████| 79/79 [00:01<00:00, 48.08it/s, acc=89%, loss=0.0123]
