In [1]:
import torch
import torch.nn as nn
from lightstream.scnn import StreamingCNN, StreamingConv2d
from torchvision.models import resnet18, resnet34, resnet50

In [2]:
torch.set_printoptions(precision=10)

## Model definition

In [3]:
resnet = resnet18(weights="IMAGENET1K_V1")

def split_model(model):
    model.layer4 = torch.nn.Sequential()
    stream_net = nn.Sequential(
        model.conv1,
        model.bn1,
        model.relu,
        model.maxpool,
        model.layer1,
        model.layer2,
        model.layer3,
        model.layer4,
        torch.nn.MaxPool2d(8, stride=8, ceil_mode=False)
    )
    head = nn.Sequential(model.avgpool, nn.Flatten(), model.fc)
    return stream_net, head

In [4]:
stream_net, head = split_model(resnet)

def freeze_bn_layers(model):
    for mod in model.modules():
        if isinstance(mod, torch.nn.BatchNorm2d):
            mod.eval()

freeze_bn_layers(stream_net)

In [5]:
"""
padding = 0

stream_net = torch.nn.Sequential(
    torch.nn.Conv2d(3, 16, kernel_size=3, padding=padding), torch.nn.ReLU(),
    torch.nn.Conv2d(16, 16, kernel_size=3, padding=padding), torch.nn.ReLU(),
    torch.nn.MaxPool2d(2),
    torch.nn.Conv2d(16, 16, kernel_size=3, padding=padding), torch.nn.ReLU(),
    torch.nn.Conv2d(16, 16, kernel_size=3, padding=padding), torch.nn.ReLU(),
    torch.nn.MaxPool2d(2),
    torch.nn.Conv2d(16, 16, kernel_size=3, padding=padding), torch.nn.ReLU(),
    torch.nn.Conv2d(16, 16, kernel_size=3, padding=padding), torch.nn.ReLU(),
    torch.nn.MaxPool2d(2))
"""

'\npadding = 0\n\nstream_net = torch.nn.Sequential(\n    torch.nn.Conv2d(3, 16, kernel_size=3, padding=padding), torch.nn.ReLU(),\n    torch.nn.Conv2d(16, 16, kernel_size=3, padding=padding), torch.nn.ReLU(),\n    torch.nn.MaxPool2d(2),\n    torch.nn.Conv2d(16, 16, kernel_size=3, padding=padding), torch.nn.ReLU(),\n    torch.nn.Conv2d(16, 16, kernel_size=3, padding=padding), torch.nn.ReLU(),\n    torch.nn.MaxPool2d(2),\n    torch.nn.Conv2d(16, 16, kernel_size=3, padding=padding), torch.nn.ReLU(),\n    torch.nn.Conv2d(16, 16, kernel_size=3, padding=padding), torch.nn.ReLU(),\n    torch.nn.MaxPool2d(2))\n'

In [6]:
for i, layer in enumerate(stream_net.modules()):
    if isinstance(layer, torch.nn.Conv2d):
        layer.weight.data *= 1.0
        
        if layer.bias is not None:
            layer.bias.data.zero_()

In [7]:
print(stream_net)

Sequential(
  (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Con

## Configurations

In [8]:
tile_size = 128*15
img_size = 128*25

cuda = True  # execute this notebook on the GPU
verbose = True   # enable / disable logging
dtype = torch.float64  # test with double precision

In [9]:

stream_net.type(dtype)
if cuda: stream_net.cuda()

## Configure streamingCNN
IMPORTANT: setting gather_gradients to True makes the class save all the gradients of the intermediate feature maps. This is needed because we want to compare the feature map gradients between streaming and conventional backpropagation. However this also counteracts the memory gains by StreamingCNN. If you want to test the memory efficiency, set gather_gradients to False

In [10]:
sCNN = StreamingCNN(stream_net, 
                    tile_shape=(1, 3, tile_size, tile_size), 
                    verbose=True,
                    saliency=True)

Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) 
 Lost(top:2.0, left:2.0, bottom:1.0, right:1.0)
MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False) 
 Lost(top:2.0, left:2.0, bottom:1.0, right:1.0)
Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 
 Lost(top:3.0, left:3.0, bottom:2.0, right:2.0)
Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 
 Lost(top:4.0, left:4.0, bottom:3.0, right:3.0)
Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 
 Lost(top:5.0, left:5.0, bottom:4.0, right:4.0)
Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 
 Lost(top:6.0, left:6.0, bottom:5.0, right:5.0)
Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) 
 Lost(top:4.0, left:4.0, bottom:3.0, right:3.0)
Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) 
 Lost(top:5.0, left:5.0, bottom:4

If the verbose flag is True than StreamingCNN will print for every layer in the network the required overlap that is needed to reconstruct the feature maps and gradients. The higher this is, the more tiles are needed to be inferences. It is always beneficial to increase the tile size as much as possible to make use of all the GPU memory.

## Generate random image and fake label

In [11]:
image = torch.FloatTensor(3, img_size, img_size).normal_(0, 1)
target = torch.tensor(50.)  # large value so we get larger gradients

image = image.type(dtype)
target = target.type(dtype)

if cuda:
    target = target.cuda()
    image = image.cuda()

In [12]:
criterion = torch.nn.MSELoss()

## Run through network using streaming

In [13]:
stream_output = sCNN.forward(image[None])
print(stream_output.shape)
stream_output.max()

torch.Size([1, 256, 25, 25])


tensor(1.6426521601, device='cuda:0', dtype=torch.float64)

In [14]:
stream_output.requires_grad = True

In [15]:
output = torch.sigmoid(torch.mean(stream_output)); output

tensor(0.5622967603, device='cuda:0', dtype=torch.float64,
       grad_fn=<SigmoidBackward0>)

In [16]:
loss = criterion(output, target)
loss

tensor(2444.0865016126, device='cuda:0', dtype=torch.float64,
       grad_fn=<MseLossBackward0>)

In [17]:
loss.backward()

In [18]:
print(stream_output.shape)
print(stream_output.grad.shape)
full_gradients = sCNN.backward(image[None], stream_output.grad)

torch.Size([1, 256, 25, 25])
torch.Size([1, 256, 25, 25])


In [19]:
sCNN.saliency_map.shape

torch.Size([1, 3, 3200, 3200])

In [20]:
streaming_conv_gradients = []

for i, layer in enumerate(stream_net.modules()):
    if isinstance(layer, StreamingConv2d):
        if layer.weight.grad is not None:
            streaming_conv_gradients.append(layer.weight.grad.clone()) 

In [21]:
sCNN.disable()

streaming_conv_gradients


[tensor([[[[-6.1936418192e-03,  1.4930331229e-03,  3.0730868068e-03,
             ...,  2.3383312697e-04,  1.9941525800e-03,
             1.0957454993e-02],
           [-3.1255359912e-03, -2.2401213888e-04, -1.0065916789e-03,
             ...,  1.4354475930e-02,  1.8567064770e-03,
             5.0425057378e-03],
           [-4.8022448455e-03,  6.6527715808e-04, -5.8017796134e-03,
             ..., -8.8509797163e-03, -6.0979249873e-03,
            -1.9851248127e-05],
           ...,
           [-1.5625450686e-03,  5.1801123736e-04, -2.9995581110e-03,
             ...,  6.4762936050e-03,  2.6887309108e-03,
             2.2248721719e-03],
           [-1.2896846083e-03,  2.4186087204e-03, -1.1747042747e-04,
             ..., -4.3508782937e-03, -7.3422008318e-03,
            -2.5769987889e-03],
           [ 1.3542124476e-03, -7.4192189357e-05,  2.4177188221e-03,
             ...,  4.4700808766e-03, -2.4847236489e-03,
            -1.9376310044e-03]],
 
          [[-7.3105460966e-03, -9.99253

## Compare to conventional training

In [22]:
resnet = resnet18(weights="IMAGENET1K_V1")
stream_net, head = split_model(resnet)

stream_net, head = split_model(resnet)

def freeze_bn_layers(model):
    for mod in model.modules():
        if isinstance(mod, torch.nn.BatchNorm2d):
            mod.eval()

freeze_bn_layers(stream_net)
stream_net.type(dtype)
if cuda: stream_net.cuda()


for i, layer in enumerate(stream_net.modules()):
    if isinstance(layer, torch.nn.Conv2d):
        if layer.weight.grad is not None:
            layer.weight.grad.data.zero_()
            layer.bias.grad.data.zero_()
            


In [23]:
conventional_gradients = []
inps = []

def save_grad(module, grad_in, grad_out):
    global conventional_gradients
    conventional_gradients.append(grad_out[0].clone())
        
for i, layer in enumerate(stream_net.modules()):
    if isinstance(layer, torch.nn.Conv2d):
        layer.register_backward_hook(save_grad)

This output should be the same as the streaming output, if so, the loss will also be the same:



In [24]:
image.requires_grad = True
stream_net.to('cpu')
conventional_output = stream_net(image[None].to('cpu')); conventional_output.max()
conventional_output.shape



torch.Size([1, 256, 25, 25])

In [25]:
print(conventional_output.shape)
stream_output.shape

torch.Size([1, 256, 25, 25])


torch.Size([1, 256, 25, 25])

In [26]:
# NOTE: sometimes output can be slightly bigger 
# (if tiles do not fit nicely on input image according to output stride)
# In that case this check may fail.
print(stream_output.shape, conventional_output.shape)
max_error = torch.abs(stream_output.detach().cpu() - conventional_output).max().item()

if max_error < 1e-7:
    print("Equal output to streaming")
else:
    print("NOT equal output to streaming"),
    print("error:", max_error)

torch.Size([1, 256, 25, 25]) torch.Size([1, 256, 25, 25])
Equal output to streaming


In [27]:
output = torch.sigmoid(torch.mean(conventional_output)); output

tensor(0.5622967603, dtype=torch.float64, grad_fn=<SigmoidBackward0>)

In [28]:
loss = criterion(output, target); loss

RuntimeError: iter.device(arg).is_cuda() INTERNAL ASSERT FAILED at "../aten/src/ATen/native/cuda/Loops.cuh":89, please report a bug to PyTorch. argument 1: expected a CUDA device but found cpu

In [51]:
loss.backward()

In [52]:
conventional_gradients[-1].shape

torch.Size([1, 64, 752, 752])

## Compare the gradients of the input image
Using the saliency argument, we can compute the gradient w.r.t to the input image. If streaming is the same as conventional training, these gradients should be roughly equal

In [53]:
diff = image.grad.detach().cpu().numpy() - sCNN.saliency_map[0].numpy()
print(diff.max())

0.0020155576469186457


## Compare the gradients of the conv2d layers

In [54]:
normal_conv_gradients = []
j = 0
for i, layer in enumerate(stream_net.modules()):
    if isinstance(layer, torch.nn.Conv2d):
        if layer.weight.grad is not None:
            normal_conv_gradients.append(layer.weight.grad) 
            print('Conv layer', j, '\t', layer)
            j += 1

Conv layer 0 	 Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
Conv layer 1 	 Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv layer 2 	 Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv layer 3 	 Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv layer 4 	 Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv layer 5 	 Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
Conv layer 6 	 Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv layer 7 	 Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
Conv layer 8 	 Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv layer 9 	 Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
Conv layer 10 	 Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1,

In [55]:
print('Conventional', '\n')

for i in range(len(streaming_conv_gradients)):
    print("Conv layer", i, "\t average gradient size:", 
          float(torch.mean(torch.abs(streaming_conv_gradients[i].data))))

Conventional 

Conv layer 0 	 average gradient size: 0.025278384389962636
Conv layer 1 	 average gradient size: 0.1965582114937199
Conv layer 2 	 average gradient size: 0.13338380764486021
Conv layer 3 	 average gradient size: 0.21788707053958006
Conv layer 4 	 average gradient size: 0.12211251266994495
Conv layer 5 	 average gradient size: 0.27629812906776163
Conv layer 6 	 average gradient size: 0.1499612032093409
Conv layer 7 	 average gradient size: 0.5234610677786308
Conv layer 8 	 average gradient size: 0.16657202504835927
Conv layer 9 	 average gradient size: 0.07438151845975027
Conv layer 10 	 average gradient size: 0.06905372369649644
Conv layer 11 	 average gradient size: 0.03219021375690993
Conv layer 12 	 average gradient size: 0.08026865727058952
Conv layer 13 	 average gradient size: 0.055126688412476296
Conv layer 14 	 average gradient size: 0.028606464769841595
Conv layer 15 	 average gradient size: 0.021730748478855963
Conv layer 16 	 average gradient size: 0.005808344

In [56]:

print('Streaming', '\n')
for i in range(len(normal_conv_gradients)):
    print("Conv layer", i, "\t average gradient size:", 
          float(torch.mean(torch.abs(normal_conv_gradients[i].data))))

Streaming 

Conv layer 0 	 average gradient size: 0.025278384389962643
Conv layer 1 	 average gradient size: 0.19655821149371994
Conv layer 2 	 average gradient size: 0.13338380764486021
Conv layer 3 	 average gradient size: 0.21788707053958
Conv layer 4 	 average gradient size: 0.12211251266994494
Conv layer 5 	 average gradient size: 0.27629812906776163
Conv layer 6 	 average gradient size: 0.14996120320934092
Conv layer 7 	 average gradient size: 0.5234610677786309
Conv layer 8 	 average gradient size: 0.16657202504835927
Conv layer 9 	 average gradient size: 0.07438151845975027
Conv layer 10 	 average gradient size: 0.06905372369649644
Conv layer 11 	 average gradient size: 0.03219021375690993
Conv layer 12 	 average gradient size: 0.08026865727058952
Conv layer 13 	 average gradient size: 0.0551266884124763
Conv layer 14 	 average gradient size: 0.028606464769841595
Conv layer 15 	 average gradient size: 0.021730748478855963
Conv layer 16 	 average gradient size: 0.005808344896697

In [57]:
for i in range(len(streaming_conv_gradients)):
    diff = torch.abs(streaming_conv_gradients[i].data - normal_conv_gradients[i].data)
    max_diff = diff.max()
    print("Conv layer", i, "\t max difference between kernel gradients:", 
          float(max_diff))

Conv layer 0 	 max difference between kernel gradients: 5.995204332975845e-15
Conv layer 1 	 max difference between kernel gradients: 1.9095836023552692e-13
Conv layer 2 	 max difference between kernel gradients: 7.061018436615996e-14
Conv layer 3 	 max difference between kernel gradients: 1.6608936448392342e-13
Conv layer 4 	 max difference between kernel gradients: 1.0835776720341528e-13
Conv layer 5 	 max difference between kernel gradients: 4.574118861455645e-14
Conv layer 6 	 max difference between kernel gradients: 1.1546319456101628e-13
Conv layer 7 	 max difference between kernel gradients: 6.483702463810914e-14
Conv layer 8 	 max difference between kernel gradients: 6.483702463810914e-14
Conv layer 9 	 max difference between kernel gradients: 7.016609515630989e-14
Conv layer 10 	 max difference between kernel gradients: 2.3092638912203256e-14
Conv layer 11 	 max difference between kernel gradients: 1.554312234475219e-14
Conv layer 12 	 max difference between kernel gradients: 