In [2]:
import torch
import torch.nn as nn
import time

In [3]:
# Compute device for this notebook, pinned to the CPU.
device = torch.device(type="cpu")

## A dilated convolution on an intermediate map is effectively the same as a convolution with stride = 1.5
This likely makes backpropagation slower, because of the extra operations needed to process the original matrix.

Only even multiples of 1.5 are integers, so the filter must be applied an odd number of times along each of height and width: once at the zero-indexed starting position, plus an even number of further steps.  
Along the output map H×W (1-indexed):  
each even index along H or W means that output combines one more pair of neighbouring pixels, i.e.  
- (1,1): 0 pairs (raw input)  
- (1,2): 1 pair (along the 2nd dimension)  
- (2,2): 2 pairs (along both dimensions)  

With an input size of $3n$ and a filter size of 3, the output size is

\begin{align}
\frac{3n-3}{1.5} + 1 &= \frac{3(n-1)}{0.5 \cdot 3} + 1 \\
&= 2(n-1) + 1 \\
&= 2n - 2 + 1 \\
&= 2n - 1
\end{align}

#### For every integer n > 0, 2n - 1 is a positive odd number (for both H and W), given stride = 1.5, filter size = 3 and an input size that is a multiple of 3
#### But not every odd number is a multiple of 3, so the output size is not always a valid input size for another stride-1.5 layer :(

### A cheaper implementation
The output of better_compute is meant to be fed into a dilated convolution (one gap between kernel taps, i.e. `dilation=2` in PyTorch's convention).
See the [convolution visualizer](https://ezyang.github.io/convolution-visualizer/index.html) for a visualization.
Passing the output of better_compute into a convolution layer with kernel size 3, dilation 2 and stride 3 is equivalent to applying a stride-1.5 convolution with kernel size 3 to the input of better_compute.

Based on this observation:

1. Insert an empty column between every pair of adjacent columns, and fill it with the average of its left and right neighbours.
2. Insert an empty row between every pair of adjacent rows, and fill it with the average of its top and bottom neighbours.
3. Run a dilated convolution on this new feature map; this is equivalent to a stride-1.5 convolution.

### Final version: decent speed for both forward and backward propagation

In [4]:
def better_compute3(input_mtx, filtersize=3):
    """Build the intermediate map that turns a stride-1.5 convolution into a dilated one.

    Each spatial dimension of size ``n`` is expanded to ``2n - 1``. Original
    values land on even indices. Every odd index holds the average of its two
    neighbours: columns are filled first, then rows, so the centre of each 2x2
    cell is the mean of its four corners.

    Convolving the result with kernel size 3, dilation 2 and stride 3 is
    equivalent to a stride-1.5, kernel-size-3 convolution on ``input_mtx``.
    Numerically, the map matches
    ``F.interpolate(x, size=(2H-1, 2W-1), mode='bilinear', align_corners=True)``
    up to floating-point rounding.

    The whole batch and all channels are processed with tensor slicing (no
    Python loops). The slice assignments are recorded by autograd, so the
    result is differentiable with respect to ``input_mtx``.

    Args:
        input_mtx: 4D tensor laid out as (batch, channels, height, width).
        filtersize: kernel size of the stride-1.5 convolution being emulated.
            It is only used to check that the input height and width give a
            positive integer output size.

    Returns:
        Tensor of shape (batch, channels, 2*height - 1, 2*width - 1) on the
        same device as ``input_mtx``. Floating-point inputs keep their dtype;
        other dtypes are promoted to ``torch.get_default_dtype()``.

    Raises:
        AssertionError: if ``input_mtx`` is not 4D, or if its height/width do
            not give a positive integer output size for stride 1.5.
    """
    stride = 1.5  # the interleave-and-average scheme below only works for stride 1.5

    assert input_mtx.dim() == 4, \
        "Input tensor dimension is %dD instead of 4D" % input_mtx.dim()

    batchsize, channels, input_rows, input_cols = input_mtx.shape

    rows = ((input_rows - filtersize) / stride) + 1  # output H dimension
    columns = ((input_cols - filtersize) / stride) + 1  # output W dimension
    assert rows % 1 == 0 and columns % 1 == 0, \
        "Invalid output HxW dimension, current output dimension for HxW is %f x %f" % (rows, columns)  # safety check
    # A negative integer size (input smaller than the filter) would otherwise
    # slip through the check above and fail later with a negative buffer size.
    assert rows >= 1 and columns >= 1, \
        "Input HxW %d x %d is smaller than filter size %d" % (input_rows, input_cols, filtersize)

    new_rows = (2 * input_rows) - 1  # intermediate H dimension
    new_cols = (2 * input_cols) - 1  # intermediate W dimension

    # Average in floating point. Float inputs keep their dtype; integer/bool
    # inputs are promoted *before* adding, so midpoints are neither truncated
    # nor overflowed (e.g. uint8 200 + 100).
    out_dtype = input_mtx.dtype if input_mtx.is_floating_point() else torch.get_default_dtype()
    src = input_mtx.to(out_dtype)

    output2 = torch.zeros(batchsize, channels, new_rows, new_cols,
                          dtype=out_dtype, device=input_mtx.device)

    # Even rows: original pixels on even columns, left/right midpoints on odd columns.
    output2[..., ::2, ::2] = src
    output2[..., ::2, 1::2] = (src[..., :-1] + src[..., 1:]) / 2
    # Odd rows: midpoint of the fully populated even rows directly above and below.
    # Both slices run along the row axis only, so non-square inputs line up.
    output2[..., 1::2, :] = (output2[..., :-1:2, :] + output2[..., 2::2, :]) / 2

    return output2
    
    