In [106]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class Net(nn.Module):
    """Single-layer 3-D convolutional network.

    Architecture: Conv3d(21 -> 21 channels, kernel (5, 5, 1), stride 1),
    then ReLU, then a 3-D max-pool with kernel size 1.
    """

    def __init__(self):
        super().__init__()
        # 21 input channels -> 21 output channels. The kernel covers 5x5 over the
        # first two spatial axes and only 1 over the last axis, so that axis is
        # never mixed by the convolution.
        self.conv1 = nn.Conv3d(21, 21, kernel_size=(5, 5, 1), stride=1)

    def forward(self, x):
        """Apply conv1 -> ReLU -> max_pool3d(kernel_size=1) and return the result.

        NOTE(review): with kernel_size=1 (stride defaults to the kernel size)
        every pooling window holds a single element, so the pool leaves the
        values unchanged; it only matters if the kernel size is raised.
        """
        activated = F.relu(self.conv1(x))
        return F.max_pool3d(activated, kernel_size=1)

    def num_flat_features(self, x):
        """Return the number of elements per sample: the product of all dims except dim 0."""
        # torch.Size.numel() is the product of the sizes (1 for an empty Size).
        return x.size()[1:].numel()


# Instantiate the network and show its layers plus the conv1 weight shape
# (out_channels, in_channels, kD, kH, kW).
net = Net()
print(net)
print(net.conv1.weight.shape)

Net(
  (conv1): Conv3d(21, 21, kernel_size=(5, 5, 1), stride=(1, 1, 1))
)
torch.Size([21, 21, 5, 5, 1])


In [11]:
# Collect conv1's parameters; the first one is conv1.weight.
params = [p for p in net.conv1.parameters()]
print(params[0].shape)
print(params[0])

torch.Size([21, 21, 5, 5, 1])
Parameter containing:
tensor([[[[[-0.0189],
           [ 0.0418],
           [-0.0013],
           [-0.0110],
           [ 0.0312]],

          [[-0.0382],
           [-0.0354],
           [ 0.0123],
           [ 0.0008],
           [-0.0324]],

          [[ 0.0217],
           [ 0.0136],
           [ 0.0196],
           [ 0.0408],
           [-0.0204]],

          [[-0.0130],
           [ 0.0185],
           [-0.0434],
           [-0.0324],
           [-0.0258]],

          [[-0.0321],
           [ 0.0220],
           [ 0.0086],
           [-0.0028],
           [ 0.0144]]],


         [[[-0.0266],
           [-0.0019],
           [-0.0418],
           [ 0.0419],
           [ 0.0192]],

          [[ 0.0431],
           [ 0.0435],
           [ 0.0192],
           [ 0.0252],
           [ 0.0266]],

          [[ 0.0141],
           [-0.0203],
           [-0.0337],
           [ 0.0385],
           [ 0.0079]],

          [[-0.0408],
           [ 0.0426],
      

           [-0.0064]]]]], requires_grad=True)


In [98]:
# Push an all-zero batch through the network to check the output shape.
# NOTE(review): `input` shadows the Python builtin; kept so any later cell
# referring to it keeps working — consider renaming.
input = torch.zeros(21, 21, 32, 32, 3)
out = net(input)
print(len(out))  # size of dim 0 of the output
print(out.size())

# Seed the RNG so the random test image — and the filters later computed from
# it — are identical on every Restart & Run All.
torch.manual_seed(0)
test = (torch.randn(3, 5, 5) * 5) ** 2  # color, height, width
print(test)
print(test[1][0][0])
print(len(test))

21
torch.Size([21, 21, 28, 28, 3])
tensor([[[3.7147e+01, 9.7168e+00, 8.0916e+01, 6.0657e+00, 1.5412e+01],
         [1.4927e+01, 4.8260e+00, 1.1910e+00, 7.9911e-01, 2.1111e+01],
         [3.8280e+00, 1.3973e+01, 1.0502e+02, 1.6421e+01, 2.0660e+01],
         [2.9112e+01, 1.7134e+01, 2.2576e+01, 2.3245e+01, 1.4811e+01],
         [6.0025e+00, 1.6040e+00, 1.0343e+01, 3.8109e+01, 5.3502e+00]],

        [[1.8206e+02, 6.8341e+00, 2.7072e-02, 3.4774e-03, 4.2575e+01],
         [4.7714e-02, 1.1960e+00, 1.1758e+01, 5.0954e+00, 2.6628e+01],
         [1.4508e+01, 3.2707e+00, 2.1422e-01, 5.0401e+00, 4.7179e+01],
         [1.4935e+00, 2.4420e+01, 4.4317e+00, 1.9669e+02, 2.8288e+00],
         [2.0904e+00, 2.0479e+01, 2.6942e+00, 4.7274e+01, 2.7179e+01]],

        [[2.0904e+00, 6.1536e+00, 7.8220e+00, 6.5054e-01, 2.0847e+00],
         [8.8162e-03, 1.1711e+02, 4.0880e-01, 4.7216e+01, 2.2093e+01],
         [4.3094e+00, 1.3744e+02, 1.7386e-01, 4.1698e+01, 1.3194e+01],
         [7.3472e+00, 3.7885e+00, 1.16

In [121]:
def calcDistance(color1, color2, color3, compColor1, compColor2, compColor3, imageX, imageY, x, y, middle,
                 colorWeight=0.5, spatialWeight=0.5):
    """Weighted colour + spatial distance between a pixel and one neighbour in its filter window.

    Parameters
    ----------
    color1, color2, color3 : number or 0-d tensor
        Channel values of the centre pixel.
    compColor1, compColor2, compColor3 : number or 0-d tensor
        Channel values of the neighbouring pixel being compared.
    imageX, imageY : int
        Position of the centre pixel in the image. Kept for backward
        compatibility; the distance does not depend on them.
    x, y : int
        Position of the neighbour inside the filter window.
    middle : int
        Index of the window centre (same along both axes).
    colorWeight, spatialWeight : float, default 0.5
        Weights of the colour term and the spatial term.

    Returns
    -------
    float
        colorWeight * euclidean(colour difference)
        + spatialWeight * euclidean((x, y) - (middle, middle)).
    """
    colorDistance = math.sqrt((color1 - compColor1) ** 2
                              + (color2 - compColor2) ** 2
                              + (color3 - compColor3) ** 2)
    # Previously written as imageX - (imageX - middle + x); the imageX terms
    # cancel, leaving the neighbour's offset from the window centre.
    dx = middle - x
    dy = middle - y
    spatialDistance = math.sqrt(dx ** 2 + dy ** 2)
    return colorWeight * colorDistance + spatialWeight * spatialDistance

In [122]:
def initializeFilters(image, filterSize=3, distanceFn=None):
    """Build, for every pixel, a square window of distances to its neighbours.

    For each pixel (imageY, imageX), a filterSize x filterSize window centred
    on it is scanned. Each neighbour that lies inside the image gets
    ``distanceFn(...)`` (``calcDistance`` by default). Neighbours outside the
    image keep the value 0.

    Parameters
    ----------
    image : tensor or nested sequence indexed as image[channel][row][col]
        Channels 0, 1 and 2 are compared. (The notebook's test image is
        (3, height, width).)
    filterSize : int, default 3
        Side length of the square window. Odd sizes centre the window on the
        pixel itself.
    distanceFn : callable, optional
        Called with the same 11 positional arguments as ``calcDistance``:
        (c1, c2, c3, n1, n2, n3, imageX, imageY, x, y, middle).
        Defaults to ``calcDistance``.

    Returns
    -------
    torch.Tensor
        Shape (height, width, filterSize, filterSize), indexed [imageY, imageX, y, x].

    Note: with the default ``calcDistance`` the centre entry is also 0 (a
    pixel's distance to itself), so 0 does not uniquely mark out-of-bounds
    neighbours.
    """
    if distanceFn is None:
        distanceFn = calcDistance
    imageHeight = len(image[0])
    imageWidth = len(image[0][0])
    middle = filterSize // 2
    filters = torch.zeros(imageHeight, imageWidth, filterSize, filterSize)  # imageY, imageX, y, x
    for imageY in range(imageHeight):
        for imageX in range(imageWidth):
            for y in range(filterSize):
                # Bounds are checked on the index actually used below, so the
                # test is correct for any window size (not only middle == 1).
                neighbourY = imageY - middle + y
                if not 0 <= neighbourY < imageHeight:
                    continue  # row outside the image: entries stay 0
                for x in range(filterSize):
                    neighbourX = imageX - middle + x
                    if not 0 <= neighbourX < imageWidth:
                        continue  # column outside the image: entry stays 0
                    filters[imageY][imageX][y][x] = distanceFn(image[0][imageY][imageX],
                                                               image[1][imageY][imageX],
                                                               image[2][imageY][imageX],
                                                               image[0][neighbourY][neighbourX],
                                                               image[1][neighbourY][neighbourX],
                                                               image[2][neighbourY][neighbourX],
                                                               imageX, imageY, x, y, middle)
    return filters

In [135]:
filters = initializeFilters(test)
print(filters.size())
# Collapse (imageY, imageX) into one pixel axis so each row is one pixel's
# window. flatten() takes the sizes from the tensor instead of hard-coding
# 25 (= 5x5 pixels) and the 3x3 window, so this still works if the test image
# or filter size changes.
flatFilters = filters.flatten(start_dim=0, end_dim=1)
print(flatFilters.size())
print(flatFilters)

torch.Size([5, 5, 3, 3])
torch.Size([25, 3, 3])
tensor([[[  0.0000,   0.0000,   0.0000],
         [  0.0000,   0.0000,  89.2028],
         [  0.0000,  92.1873, 109.0887]],

        [[  0.0000,   0.0000,   0.0000],
         [ 89.2028,   0.0000,  36.2718],
         [  5.9739,  56.1049,   6.4067]],

        [[  0.0000,   0.0000,   0.0000],
         [ 36.2718,   0.0000,  38.0968],
         [ 67.2945,  40.9621,  45.4182]],

        [[  0.0000,   0.0000,   0.0000],
         [ 38.0968,   0.0000,  22.3046],
         [  7.0710,  24.0689,  19.3821]],

        [[  0.0000,   0.0000,   0.0000],
         [ 22.3046,   0.0000,   0.0000],
         [ 30.9358,  13.6066,   0.0000]],

        [[  0.0000,  92.1873,   5.9739],
         [  0.0000,   0.0000,  59.2721],
         [  0.0000,   9.8647,  69.4410]],

        [[109.0887,  56.1049,  67.2945],
         [ 59.2721,   0.0000,  59.1186],
         [ 57.5023,  11.6914,  77.7035]],

        [[  6.4067,  40.9621,   7.0710],
         [ 59.1186,   0.0000,  24.14