In [None]:
def compute_itm_4d(input_mtx,filtersize=3): #for a 4-dimensional input tensor of (BATCHSIZE,CHANNELS,HEIGHT,WIDTH)
    """Lay out every filter window of a stride-1.5 sliding filter side by side.

    The window for output position (oy, ox) starts at input coordinate
    (1.5 * oy, 1.5 * ox).  Windows that start on a half-pixel read the
    average of the two (one axis) or four (both axes) surrounding pixels.
    Each window is written as a ``filtersize x filtersize`` tile, so the
    result can be consumed by an ordinary convolution with
    ``stride=filtersize``.

    Args:
        input_mtx: 4D tensor indexed as (batch, channel, height, width).
        filtersize: side length of the square filter window.

    Returns:
        Tensor of shape (batch, channel, rows * filtersize,
        columns * filtersize) in torch's default dtype, allocated on the
        same device as ``input_mtx``, where
        ``rows = (height - filtersize) / 1.5 + 1`` and likewise for columns.

    Raises:
        AssertionError: if the input is not 4D, or if the height/width do
            not yield a whole, positive number of windows.
    """
    stride = 1.5  # algorithm only works for stride value 1.5

    assert input_mtx.dim() == 4,\
    "Input tensor dimension is %dD instead of 4D" % input_mtx.dim()

    batchsize, channels, input_rows, input_cols = input_mtx.size()

    rows = ((input_rows - filtersize) / stride) + 1     # output H dimension
    columns = ((input_cols - filtersize) / stride) + 1  # output W dimension
    # Also reject sizes that give no window at all; they would otherwise
    # only fail later with an opaque negative-dimension error.
    assert rows % 1 == 0 and columns % 1 == 0 and rows >= 1 and columns >= 1,\
    "Invalid output HxW dimension, current output dimension for HxW is %f x %f" % (rows, columns)
    rows = int(rows)
    columns = int(columns)

    # Allocate next to the input instead of relying on a module-level global.
    dev = input_mtx.device

    # Half-pixel grid: index 2*k holds input pixel k, index 2*k+1 holds the
    # midpoint between pixels k and k+1.  The operand order of each sum
    # mirrors the element-wise formulation so results are bit-identical.
    grid = torch.zeros(batchsize, channels,
                       2 * input_rows - 1, 2 * input_cols - 1, device=dev)
    grid[:, :, ::2, ::2] = input_mtx
    grid[:, :, ::2, 1::2] = (0.5 * input_mtx[:, :, :, :-1]
                             + 0.5 * input_mtx[:, :, :, 1:])
    grid[:, :, 1::2, ::2] = (0.5 * input_mtx[:, :, :-1, :]
                             + 0.5 * input_mtx[:, :, 1:, :])
    grid[:, :, 1::2, 1::2] = (0.25 * input_mtx[:, :, :-1, :-1]
                              + 0.25 * input_mtx[:, :, 1:, :-1]
                              + 0.25 * input_mtx[:, :, :-1, 1:]
                              + 0.25 * input_mtx[:, :, 1:, 1:])

    # On the half-pixel grid a stride of 1.5 is a step of 3 and neighbouring
    # taps inside one window are 2 apart.
    taps = 2 * torch.arange(filtersize, device=dev)
    idx_y = (3 * torch.arange(rows, device=dev).unsqueeze(1) + taps).reshape(-1)
    idx_x = (3 * torch.arange(columns, device=dev).unsqueeze(1) + taps).reshape(-1)

    # Broadcasting the two index vectors gathers every (row, column) pair.
    return grid[:, :, idx_y.unsqueeze(1), idx_x.unsqueeze(0)]

In [None]:
def better_compute(input_mtx,filtersize = 3):
    """Upsample a 4D tensor to (2H-1) x (2W-1) by inserting midpoints.

    Every original pixel lands on an even row/column of the result; each
    odd column is the mean of its left and right neighbours and each odd
    row is the mean of the rows above and below it.  Intended to be used
    together with a dilated convolution.

    Args:
        input_mtx: 4D tensor indexed as (batch, channel, height, width).
        filtersize: filter side length; only used to check that the input
            size is compatible with a stride of 1.5.

    Returns:
        Tensor of shape (batch, channel, 2*height - 1, 2*width - 1) in
        torch's default dtype, allocated on the same device as
        ``input_mtx``.

    Raises:
        AssertionError: if the input is not 4D or its height/width do not
            give a whole number of stride-1.5 filter positions.
    """
    #use with dilated convolution

    stride = 1.5  # algorithm only works for stride value 1.5

    assert input_mtx.dim() == 4,\
    "Input tensor dimension is %dD instead of 4D" % input_mtx.dim()

    batchsize, channels, input_rows, input_cols = input_mtx.size()

    out_rows = ((input_rows - filtersize) / stride) + 1  # output H dimension
    out_cols = ((input_cols - filtersize) / stride) + 1  # output W dimension
    assert out_rows % 1 == 0 and out_cols % 1 == 0,\
    "Invalid output HxW dimension, current output dimension for HxW is %f x %f" % (out_rows, out_cols) #safety check

    new_rows = (2 * input_rows) - 1
    new_cols = (2 * input_cols) - 1

    # Allocate next to the input instead of relying on a module-level global.
    dev = input_mtx.device

    # Horizontal pass: originals on even columns, midpoints on odd columns.
    widened = torch.zeros(batchsize, channels, input_rows, new_cols, device=dev)
    widened[..., ::2] = input_mtx
    widened[..., 1::2] = (widened[..., :-2:2] + widened[..., 2::2]) / 2

    # Vertical pass: same scheme applied to the rows of the widened tensor.
    upsampled = torch.zeros(batchsize, channels, new_rows, new_cols, device=dev)
    upsampled[:, :, ::2] = widened
    upsampled[:, :, 1::2] = (upsampled[:, :, :-2:2] + upsampled[:, :, 2::2]) / 2

    return upsampled
    

    
    

In [None]:
def better_compute2(input_mtx,filtersize = 3):
    """Upsample a 4D tensor to (2H-1) x (2W-1) and return a grad-enabled result.

    Every original pixel lands on an even row/column of the result; each
    odd column is the mean of its left and right neighbours and each odd
    row is the mean of the rows above and below it.  Intended to be used
    together with a dilated convolution.

    The result is built out of place.  Writing into a pre-allocated leaf
    tensor created with ``requires_grad=True`` is rejected by autograd
    ("a view of a leaf Variable that requires grad is being used in an
    in-place operation"), so the grad flag is applied at the end instead.

    Args:
        input_mtx: 4D tensor indexed as (batch, channel, height, width).
        filtersize: filter side length; only used to check that the input
            size is compatible with a stride of 1.5.

    Returns:
        Tensor of shape (batch, channel, 2*height - 1, 2*width - 1) in
        torch's default dtype, on the same device as ``input_mtx``, with
        ``requires_grad`` set.  When ``input_mtx`` itself requires grad (and
        grad mode is enabled) the result stays connected to it in the
        autograd graph.

    Raises:
        AssertionError: if the input is not 4D or its height/width do not
            give a whole number of stride-1.5 filter positions.
    """
    #use with dilated convolution

    stride = 1.5  # algorithm only works for stride value 1.5

    assert input_mtx.dim() == 4,\
    "Input tensor dimension is %dD instead of 4D" % input_mtx.dim()

    batchsize, channels, input_rows, input_cols = input_mtx.size()

    out_rows = ((input_rows - filtersize) / stride) + 1  # output H dimension
    out_cols = ((input_cols - filtersize) / stride) + 1  # output W dimension
    assert out_rows % 1 == 0 and out_cols % 1 == 0,\
    "Invalid output HxW dimension, current output dimension for HxW is %f x %f" % (out_rows, out_cols) #safety check

    new_rows = (2 * input_rows) - 1
    new_cols = (2 * input_cols) - 1

    # Allocate next to the input instead of relying on a module-level global.
    dev = input_mtx.device

    # Horizontal pass: originals on even columns, midpoints on odd columns.
    widened = torch.zeros(batchsize, channels, input_rows, new_cols, device=dev)
    widened[..., ::2] = input_mtx
    widened[..., 1::2] = (widened[..., :-2:2] + widened[..., 2::2]) / 2

    # Vertical pass: same scheme applied to the rows of the widened tensor.
    upsampled = torch.zeros(batchsize, channels, new_rows, new_cols, device=dev)
    upsampled[:, :, ::2] = widened
    upsampled[:, :, 1::2] = (upsampled[:, :, :-2:2] + upsampled[:, :, 2::2]) / 2

    # If the input did not bring a graph along, the result is still a leaf
    # and can simply be flagged; otherwise it already requires grad.
    if not upsampled.requires_grad:
        upsampled.requires_grad_(True)
    return upsampled
    
    

    
    