In [1]:
from IPython.core.debugger import set_trace

In [2]:
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from models import *
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms

In [3]:
class AsyncBatchNorm(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, beta, gamma, r_mu, r_sigma2):
        hath = (input - r_mu) * (r_sigma2 + 1e-5)**(-1./2.)
        ctx.save_for_backward(input, beta, gamma)
        print('wat')
        return gamma * hath + beta

    @staticmethod
    def backward(ctx, grad_output):
        input, beta, gamma = ctx.saved_tensors
        print('wut')
        dy = grad_output
        N = input.shape[0]
        eps = 1e-5

        mu = 1/N * torch.sum(input, dim=0)  # Size (H,) maybe torch.mean is faster
        sigma2 = 1/N * torch.sum((input - mu)**2, dim=0)  # Size (H,) maybe torch variance is faster

        dx = (1. / N) * gamma * (sigma2 + eps)**(-1. / 2.) * (N * dy - torch.sum(dy, dim=0)
                - (input - mu) * (sigma2 + eps)**(-1.0) * torch.sum(dy * (input - mu), dim=0))

        dbeta = torch.sum(dy, dim=0)
        dgamma = torch.sum((input - mu) * (sigma2 + eps)**(-1. / 2.) * dy, dim=0)

        return dx, dbeta, dgamma, None, None

In [4]:
class TrailingBatchNorm(torch.nn.Module):
    def __init__(self, gamma=1., beta=0.):
        super(TrailingBatchNorm, self).__init__()

        self.running_mean = None

        self.temp_running_mean = 0
        self.temp_running_variance = 0

        self.gamma = torch.tensor([1.], requires_grad=True).cuda()
        self.beta = torch.tensor([0.], requires_grad=True).cuda()
        self.eps = 1e-5
        self.momentum = 0.9  # uses the same momentum definition as pytorch batchnorm

        self.first = 0

    def forward(self, x):
        if x.dim() != 4:
            return self.batchnorm_forward_conv(x, self.gamma, self.beta)
        else:
            N, C, H, W = x.shape
            x_new = x.transpose(0, 1).contiguous().view(x.size(1), -1)
            # x_new = x.permute(0, 2, 3, 1).reshape(N * H * W, C)
            out = self.batchnorm_forward(x_new, self.gamma, self.beta)
            out = out.view(C, N, H, W).transpose(0, 1)
            return out
        
    def batchnorm_forward_conv(self, x, gamma, beta):
        if x.requires_grad:
            self.running_mean = self.temp_running_mean
            self.running_variance = self.temp_running_variance

        set_trace()
        x_chan = x.transpose(0, 1).contiguous().view(x.size(1), -1)
        
        async_batchnorm = AsyncBatchNorm.apply
        momentum = self.momentum
        eps = self.eps

        if self.first == 0:
            self.first += 1
            out = x
            if self.first == 1:
                self.running_mean = x_chan.mean(1)[:, None, None]
                self.running_variance = x_chan.std(1)[:, None, None]
        else:
            out = async_batchnorm(x, self.beta, self.gamma, self.running_mean, self.running_variance)

        if x.requires_grad:
            with torch.no_grad():
                sample_mean = x_chan.mean(1)[:, None, None]  # we do not need to calculate gradients over mean/var
                sample_var = x_chan.std(1)[:, None, None]

                self.temp_running_mean = (1 - momentum) * self.running_mean.detach() +  momentum * sample_mean
                self.temp_running_variance = (1 - momentum) * self.running_variance.detach() + momentum * sample_var

        return out

    def batchnorm_forward(self, x, gamma, beta):
        if x.requires_grad:
            self.running_mean = self.temp_running_mean
            self.running_variance = self.temp_running_variance

        async_batchnorm = AsyncBatchNorm.apply
        momentum = self.momentum
        eps = self.eps

        if self.first == 0:
            self.first += 1
            out = x
            if self.first == 1:
                self.running_mean = torch.mean(x, dim=0).detach()
                self.running_variance = torch.var(x, dim=0).detach()
        else:
            out = async_batchnorm(x, self.beta, self.gamma, self.running_mean, self.running_variance)

        if x.requires_grad:
            with torch.no_grad():
                sample_mean = torch.mean(x, dim=0)  # we do not need to calculate gradients over mean/var
                sample_var = torch.var(x, dim=0)

                self.temp_running_mean = (1 - momentum) * self.running_mean.detach() +  momentum * sample_mean
                self.temp_running_variance = (1 - momentum) * self.running_variance.detach() + momentum * sample_var

        return out


In [5]:
class Bottleneck_TBN(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck_TBN, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = TrailingBatchNorm()
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = TrailingBatchNorm()
        self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
        self.bn3 = TrailingBatchNorm()

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                TrailingBatchNorm()
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet_TBN(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet_TBN, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = TrailingBatchNorm()
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

In [None]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.batch_norm_conv = TrailingBatchNorm()
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.batch_norm_conv2 = TrailingBatchNorm()
        self.conv2_drop = nn.Dropout2d()
        self.batch_norm = TrailingBatchNorm()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = self.batch_norm_conv(x)
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = self.batch_norm_conv2(x)
        x = x.contiguous().view(-1, 320)
#         x = self.batch_norm(x)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

In [None]:
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
        batch_size=64, shuffle=True)

In [6]:
print('==> Preparing data..')
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train)

trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=64, shuffle=True, num_workers=2)

==> Preparing data..
Files already downloaded and verified


In [7]:
model = ResNet_TBN(Bottleneck_TBN, [3,4,6,3])
model = model.cuda()

In [None]:
model = Net()
model = model.cuda()

In [8]:
optimizer = optim.SGD(model.parameters(), lr=1e-5, momentum=0)

In [9]:
model.train()

ResNet_TBN(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): TrailingBatchNorm()
  (layer1): Sequential(
    (0): Bottleneck_TBN(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): TrailingBatchNorm()
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): TrailingBatchNorm()
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): TrailingBatchNorm()
      (shortcut): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): TrailingBatchNorm()
      )
    )
    (1): Bottleneck_TBN(
      (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): TrailingBatchNorm()
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): TrailingBatchNorm()
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1

In [10]:
data = None
for i, (data, target) in enumerate(trainloader):
    data, target = data.cuda(), target.cuda()
    optimizer.zero_grad()
    output = model(data)
    loss = F.nll_loss(output, target)
    loss.backward()
    optimizer.step()
    print(i)
    if i == 5:
        break

0
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
1
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
wut
2
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wat
wa

Process Process-1:
Process Process-2:
Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/hans/anaconda/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/home/hans/anaconda/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/home/hans/anaconda/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/home/hans/anaconda/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 52, in _worker_loop
    r = index_queue.get()
  File "/home/hans/anaconda/lib/python3.6/multiprocessing/queues.py", line 335, in get
    res = self._reader.recv_bytes()
  File "/home/hans/anaconda/lib/python3.6/multiprocessing/connection.py", line 216, in recv_bytes
    buf = self._recv_bytes(maxlength)
  File "/home/hans/anaconda/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File 

KeyboardInterrupt: 

In [None]:
%debug