In [4]:
import torch
from lightstream.core.scnn.scnn import StreamingCNN, StreamingConv2d

/usr/local/lib/python3.12/dist-packages/lightning/fabric/__init__.py:41: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
  from .autonotebook import tqdm as notebook_tqdm


In [5]:
torch.set_printoptions(precision=10)

## Model definition

In [6]:
padding = 0

stream_net = torch.nn.Sequential(
    torch.nn.Conv2d(3, 16, kernel_size=3, padding=padding), torch.nn.ReLU(),
    torch.nn.Conv2d(16, 16, kernel_size=3, padding=padding), torch.nn.ReLU(),
    torch.nn.MaxPool2d(2),
    torch.nn.Conv2d(16, 16, kernel_size=3, padding=padding), torch.nn.ReLU(),
    torch.nn.Conv2d(16, 16, kernel_size=3, padding=padding), torch.nn.ReLU(),
    torch.nn.MaxPool2d(2),
    torch.nn.Conv2d(16, 16, kernel_size=3, padding=padding), torch.nn.ReLU(),
    torch.nn.Conv2d(16, 16, kernel_size=3, padding=padding), torch.nn.ReLU(),
    torch.nn.MaxPool2d(2))


In [7]:
for i, layer in enumerate(stream_net.modules()):
    if isinstance(layer, torch.nn.Conv2d):
        layer.weight.data *= 1.0
        
        if layer.bias is not None:
            layer.bias.data.zero_()

In [8]:
print(stream_net)

Sequential(
  (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1))
  (1): ReLU()
  (2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1))
  (3): ReLU()
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1))
  (6): ReLU()
  (7): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1))
  (8): ReLU()
  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (10): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1))
  (11): ReLU()
  (12): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1))
  (13): ReLU()
  (14): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)


## Configurations

In [9]:
tile_size = 128*15
img_size = 128*25

cuda = True  # execute this notebook on the GPU
verbose = True   # enable / disable logging
dtype = torch.float64  # test with double precision

In [10]:

stream_net.type(dtype)
if cuda: stream_net.cuda()

## Configure streamingCNN
IMPORTANT: setting gather_gradients to True makes the class save all the gradients of the intermediate feature maps. This is needed because we want to compare the feature map gradients between streaming and conventional backpropagation. However this also counteracts the memory gains by StreamingCNN. If you want to test the memory efficiency, set gather_gradients to False

In [11]:
sCNN = StreamingCNN(stream_net, 
                    tile_shape=(1, 3, tile_size, tile_size), 
                    verbose=True,
                    saliency=True)

Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1)) 
 Lost(top:0.0, left:0.0, bottom:0.0, right:0.0)
Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1)) 
 Lost(top:0.0, left:0.0, bottom:0.0, right:0.0)
MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) 
 Lost(top:0.0, left:0.0, bottom:0.0, right:0.0)
Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1)) 
 Lost(top:0.0, left:0.0, bottom:0.0, right:0.0)
Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1)) 
 Lost(top:0.0, left:0.0, bottom:0.0, right:0.0)
MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) 
 Lost(top:0.0, left:0.0, bottom:0.0, right:0.0)
Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1)) 
 Lost(top:0.0, left:0.0, bottom:0.0, right:0.0)
Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1)) 
 Lost(top:0.0, left:0.0, bottom:0.0, right:0.0)
MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) 
 Lost(top:0.0, left:0.0, bottom:0.0, right:0.0)

 Output lost Lost(top:0.0, left:0.

If the verbose flag is True than StreamingCNN will print for every layer in the network the required overlap that is needed to reconstruct the feature maps and gradients. The higher this is, the more tiles are needed to be inferences. It is always beneficial to increase the tile size as much as possible to make use of all the GPU memory.

## Generate random image and fake label

In [12]:
image = torch.FloatTensor(3, img_size, img_size).normal_(0, 1)
target = torch.tensor(50.)  # large value so we get larger gradients

image = image.type(dtype)
target = target.type(dtype)

if cuda:
    target = target.cuda()
    image = image.cuda()

In [13]:
criterion = torch.nn.MSELoss()

## Run through network using streaming

In [14]:
stream_output = sCNN.forward(image[None])
print(stream_output.shape)
stream_output.max()

torch.Size([1, 16, 396, 396])


tensor(0.0553336525, device='cuda:0', dtype=torch.float64)

In [15]:
stream_output.requires_grad = True

In [16]:
output = torch.sigmoid(torch.mean(stream_output)); output

tensor(0.5024655418, device='cuda:0', dtype=torch.float64,
       grad_fn=<SigmoidBackward0>)

In [17]:
loss = criterion(output, target)
loss

tensor(2450.0059174429, device='cuda:0', dtype=torch.float64,
       grad_fn=<MseLossBackward0>)

In [18]:
loss.backward()

In [19]:
print(stream_output.shape)
print(stream_output.grad.shape)
full_gradients = sCNN.backward(image[None], stream_output.grad)

torch.Size([1, 16, 396, 396])
torch.Size([1, 16, 396, 396])


In [22]:
sCNN.saliency_map.shape

torch.Size([1, 3, 3200, 3200])

In [23]:
streaming_conv_gradients = []

for i, layer in enumerate(stream_net.modules()):
    if isinstance(layer, StreamingConv2d):
        if layer.weight.grad is not None:
            streaming_conv_gradients.append(layer.weight.grad.clone()) 

In [24]:
sCNN.disable()

## Compare to conventional training

In [25]:
stream_net.type(dtype)
if cuda: stream_net.cuda()

for i, layer in enumerate(stream_net.modules()):
    if isinstance(layer, torch.nn.Conv2d):
        if layer.weight.grad is not None:
            layer.weight.grad.data.zero_()
            layer.bias.grad.data.zero_()
            


In [26]:
conventional_gradients = []
inps = []

def save_grad(module, grad_in, grad_out):
    global conventional_gradients
    conventional_gradients.append(grad_out[0].clone())
        
for i, layer in enumerate(stream_net.modules()):
    if isinstance(layer, torch.nn.Conv2d):
        layer.register_backward_hook(save_grad)

This output should be the same as the streaming output, if so, the loss will also be the same:



In [27]:
image.requires_grad = True
conventional_output = stream_net(image[None]); conventional_output.max()
conventional_output.shape

  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)


torch.Size([1, 16, 396, 396])

In [28]:
print(conventional_output.shape)
stream_output.shape

torch.Size([1, 16, 396, 396])


torch.Size([1, 16, 396, 396])

In [29]:
# NOTE: sometimes output can be slightly bigger 
# (if tiles do not fit nicely on input image according to output stride)
# In that case this check may fail.
print(stream_output.shape, conventional_output.shape)
max_error = torch.abs(stream_output.detach().cpu() - conventional_output.detach().cpu()).max().item()

if max_error < 1e-7:
    print("Equal output to streaming")
else:
    print("NOT equal output to streaming"),
    print("error:", max_error)

torch.Size([1, 16, 396, 396]) torch.Size([1, 16, 396, 396])
Equal output to streaming


In [30]:
output = torch.sigmoid(torch.mean(conventional_output)); output

tensor(0.5024655418, device='cuda:0', dtype=torch.float64,
       grad_fn=<SigmoidBackward0>)

In [31]:
loss = criterion(output, target); loss

tensor(2450.0059174429, device='cuda:0', dtype=torch.float64,
       grad_fn=<MseLossBackward0>)

In [32]:
loss.backward()

In [33]:
conventional_gradients[-1].shape

torch.Size([1, 16, 3198, 3198])

## Compare the gradients of the input image
Using the saliency argument, we can compute the gradient w.r.t to the input image. If streaming is the same as conventional training, these gradients should be roughly equal

In [34]:
diff = image.grad.detach().cpu().numpy() - sCNN.saliency_map[0].numpy()
print(diff.max())

1.3234889800848443e-22


## Compare the gradients of the conv2d layers

In [35]:
normal_conv_gradients = []
j = 0
for i, layer in enumerate(stream_net.modules()):
    if isinstance(layer, torch.nn.Conv2d):
        if layer.weight.grad is not None:
            normal_conv_gradients.append(layer.weight.grad) 
            print('Conv layer', j, '\t', layer)
            j += 1

Conv layer 0 	 Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1))
Conv layer 1 	 Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1))
Conv layer 2 	 Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1))
Conv layer 3 	 Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1))
Conv layer 4 	 Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1))
Conv layer 5 	 Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1))


In [36]:
print('Conventional', '\n')

for i in range(len(streaming_conv_gradients)):
    print("Conv layer", i, "\t average gradient size:", 
          float(torch.mean(torch.abs(streaming_conv_gradients[i].data))))

Conventional 

Conv layer 0 	 average gradient size: 0.005337106989303007
Conv layer 1 	 average gradient size: 0.011621452477908677
Conv layer 2 	 average gradient size: 0.016327254261941933
Conv layer 3 	 average gradient size: 0.014050563924988376
Conv layer 4 	 average gradient size: 0.012838129856554763
Conv layer 5 	 average gradient size: 0.02153204154635157


In [37]:

print('Streaming', '\n')
for i in range(len(normal_conv_gradients)):
    print("Conv layer", i, "\t average gradient size:", 
          float(torch.mean(torch.abs(normal_conv_gradients[i].data))))

Streaming 

Conv layer 0 	 average gradient size: 0.005337106989303007
Conv layer 1 	 average gradient size: 0.011621452477908679
Conv layer 2 	 average gradient size: 0.01632725426194193
Conv layer 3 	 average gradient size: 0.014050563924988388
Conv layer 4 	 average gradient size: 0.012838129856554773
Conv layer 5 	 average gradient size: 0.021532041546351636


In [38]:
for i in range(len(streaming_conv_gradients)):
    diff = torch.abs(streaming_conv_gradients[i].data - normal_conv_gradients[i].data)
    max_diff = diff.max()
    print("Conv layer", i, "\t max difference between kernel gradients:", 
          float(max_diff))

Conv layer 0 	 max difference between kernel gradients: 1.7416623698807143e-15
Conv layer 1 	 max difference between kernel gradients: 3.802513859341161e-15
Conv layer 2 	 max difference between kernel gradients: 5.405398351143731e-15
Conv layer 3 	 max difference between kernel gradients: 5.703770789011742e-15
Conv layer 4 	 max difference between kernel gradients: 9.020562075079397e-15
Conv layer 5 	 max difference between kernel gradients: 3.552713678800501e-15
