In [1]:
import numpy as np
# Fix NumPy's global RNG seed so any random draws made later in the notebook
# are reproducible from a fresh kernel.
np.random.seed(1)

# 2 im2col関数とcol2im関数

In [2]:
def im2col(x, fil_size, y_size, stride, pad):
    """Unfold a batch of images into a 2-D matrix of filter patches.

    Args:
        x: 4-D array of shape (batch, channels, height, width).
        fil_size: Side length of the square filter.
        y_size: Side length of the square output feature map.
        stride: Step between neighbouring filter positions.
        pad: Number of zeros added on every side of the two spatial axes.

    Returns:
        float64 array of shape
        (channels * fil_size * fil_size, batch * y_size * y_size).
        Rows are ordered by (channel, filter row, filter column) and columns
        by (batch, output row, output column).
    """
    n_batch, n_ch, _, _ = x.shape
    n_taps = fil_size * fil_size
    # Extent covered by y_size samples taken every `stride` pixels.
    span = y_size * stride

    padded = np.pad(x, [(0, 0), (0, 0), (pad, pad), (pad, pad)], "constant")
    patches = np.zeros((n_taps, n_batch, n_ch, y_size, y_size))

    # One strided slice per filter tap: it gathers, for every output position,
    # the input pixel that this tap touches.
    for tap in range(n_taps):
        row, col = divmod(tap, fil_size)
        patches[tap] = padded[:, :, row:row + span:stride, col:col + span:stride]

    # (tap, batch, ch, y, x) -> (ch, tap, batch, y, x), then flatten to 2-D.
    patches = patches.transpose(2, 0, 1, 3, 4)
    return patches.reshape(n_ch * n_taps, n_batch * y_size * y_size)

In [3]:
def col2im(dx_col, x_shape, fil_size, y_size, stride, pad):
    """Fold a patch matrix back into image layout, summing overlapping pixels.

    This is the scatter-add counterpart of im2col.

    Args:
        dx_col: 2-D array of shape
            (channels * fil_size * fil_size, batch * y_size * y_size).
        x_shape: Target shape (batch, channels, height, width).
        fil_size: Side length of the square filter.
        y_size: Side length of the square output feature map.
        stride: Step between neighbouring filter positions.
        pad: Padding that was applied on every side of the spatial axes.

    Returns:
        float64 array of shape ``x_shape`` with the padding border removed.
    """
    n_batch, n_ch, img_h, img_w = x_shape
    n_taps = fil_size * fil_size
    span = y_size * stride

    # (ch, tap, batch, y, x) -> (tap, batch, ch, y, x)
    taps = dx_col.reshape(n_ch, n_taps, n_batch, y_size, y_size)
    taps = taps.transpose(1, 2, 0, 3, 4)

    # Padded canvas, plus stride - 1 spare rows/columns beyond the padding.
    margin = 2 * pad + stride - 1
    canvas = np.zeros((n_batch, n_ch, img_h + margin, img_w + margin))

    # Accumulate each filter tap's contribution at the pixels it was read from.
    for tap in range(n_taps):
        row, col = divmod(tap, fil_size)
        canvas[:, :, row:row + span:stride, col:col + span:stride] += taps[tap]

    # Crop the padding (and the spare margin) away.
    return canvas[:, :, pad:pad + img_h, pad:pad + img_w]

# 3 convolutional層

In [4]:
class Conv:
    """2-D convolution layer evaluated as a matrix product via im2col/col2im.

    The weights have shape (y_c, x_c, fil_size, fil_size) and are filled with
    the consecutive integers 0, 1, 2, ... so that the worked example is
    deterministic and easy to check by hand. The bias has shape (1, y_c) and
    starts at zero.
    """

    def __init__(self, x_c, y_c, fil_size, stride, pad):
        """Create the layer.

        Args:
            x_c: Number of input channels.
            y_c: Number of output channels (filters).
            fil_size: Side length of the square filter.
            stride: Step between neighbouring filter positions.
            pad: Zero padding added on every side of the input.
        """
        self.x_c, self.y_c = x_c, y_c
        self.fil_h, self.fil_w = fil_size, fil_size
        self.stride, self.pad = stride, pad

        # The weight shape follows the constructor arguments rather than a
        # fixed (2, 3, 3, 3); for Conv(3, 2, 3, ...) the values are exactly
        # np.arange(54).reshape(2, 3, 3, 3).
        w_shape = (self.y_c, self.x_c, self.fil_h, self.fil_w)
        n_weights = self.y_c * self.x_c * self.fil_h * self.fil_w
        self.w = np.arange(n_weights).reshape(w_shape)
        self.b = np.zeros((1, self.y_c))
        # For randomly initialised parameters use instead:
        #   self.w = np.random.randn(self.y_c, self.x_c, self.fil_h, self.fil_w)
        #   self.b = np.random.randn(1, self.y_c)

    def forward(self, x):
        """Convolve a batch of images.

        Args:
            x: Array of shape (batch, x_c, height, width).

        Returns:
            Array of shape (batch, y_c, y_h, y_w); also stored as ``self.y``.

        Raises:
            ValueError: If ``x`` is not 4-D, if its channel count differs from
                ``x_c``, or if the output feature map would not be square
                (im2col/col2im take a single output size for both axes).
        """
        x_b, x_c, x_h, x_w = x.shape
        if x_c != self.x_c:
            raise ValueError(
                "input has %d channels but the layer expects %d" % (x_c, self.x_c)
            )

        y_h = (x_h - self.fil_h + 2*self.pad) // self.stride + 1
        y_w = (x_w - self.fil_w + 2*self.pad) // self.stride + 1
        if y_h != y_w:
            raise ValueError(
                "only square output maps are supported, got %d x %d" % (y_h, y_w)
            )

        # Cache the geometry and the unfolded input; backward() needs them.
        self.xshape = x.shape
        self.x_b, self.x_h, self.x_w = x_b, x_h, x_w
        self.y_h, self.y_w = y_h, y_w

        self.x_col = im2col(x, self.fil_h, self.y_h, self.stride, self.pad)
        self.w_col = self.w.reshape(self.y_c, self.x_c*self.fil_h*self.fil_w)

        # (y_c, batch*y_h*y_w) -> (batch*y_h*y_w, y_c), then add the bias row.
        y = np.dot(self.w_col, self.x_col).T + self.b
        self.y = y.reshape(self.x_b, self.y_h, self.y_w, self.y_c).transpose(0, 3, 1, 2)

        return self.y

    def backward(self, dy):
        """Back-propagate the gradient of the layer output.

        Must be called after forward(), which caches the unfolded input and
        the output geometry.

        Args:
            dy: Gradient with respect to the output,
                shape (batch, y_c, y_h, y_w).

        Returns:
            Gradient with respect to the input, with the shape of the ``x``
            given to forward(); also stored as ``self.dx``. The parameter
            gradients are stored as ``self.dw`` (same shape as ``self.w``)
            and ``self.db`` (shape (y_c,)).
        """
        # Same (batch, y_h, y_w) row ordering as the columns of x_col.
        dy = dy.transpose(0, 2, 3, 1).reshape(self.x_b*self.y_h*self.y_w, self.y_c)
        dw = np.dot(self.x_col, dy)

        self.dw = dw.T.reshape(self.y_c, self.x_c, self.fil_h, self.fil_w)
        self.db = np.sum(dy, axis=0)

        dx_col = np.dot(dy, self.w_col)
        self.dx = col2im(dx_col.T, self.xshape, self.fil_h, self.y_h, self.stride, self.pad)

        return self.dx

In [5]:
# Toy input batch of shape (2, 3, 4, 4) filled with the integers 0..95, so
# every value in the results below can be traced back by hand.
x = np.arange(2 * 3 * 4 * 4).reshape(2, 3, 4, 4)
x.shape, x

((2, 3, 4, 4),
 array([[[[ 0,  1,  2,  3],
          [ 4,  5,  6,  7],
          [ 8,  9, 10, 11],
          [12, 13, 14, 15]],
 
         [[16, 17, 18, 19],
          [20, 21, 22, 23],
          [24, 25, 26, 27],
          [28, 29, 30, 31]],
 
         [[32, 33, 34, 35],
          [36, 37, 38, 39],
          [40, 41, 42, 43],
          [44, 45, 46, 47]]],
 
 
        [[[48, 49, 50, 51],
          [52, 53, 54, 55],
          [56, 57, 58, 59],
          [60, 61, 62, 63]],
 
         [[64, 65, 66, 67],
          [68, 69, 70, 71],
          [72, 73, 74, 75],
          [76, 77, 78, 79]],
 
         [[80, 81, 82, 83],
          [84, 85, 86, 87],
          [88, 89, 90, 91],
          [92, 93, 94, 95]]]]))

In [6]:
# 3 input channels -> 2 output channels, 3x3 filter, stride 1, no padding.
conv = Conv(x_c=3, y_c=2, fil_size=3, stride=1, pad=0)

In [7]:
# Forward pass: with a 3x3 filter, stride 1 and no padding, the 4x4 input
# gives a (2, 2, 2, 2) result laid out as (batch, out-channel, row, column).
y = conv.forward(x)
y.shape, y

((2, 2, 2, 2),
 array([[[[10197., 10548.],
          [11601., 11952.]],
 
         [[25506., 26586.],
          [29826., 30906.]]],
 
 
        [[[27045., 27396.],
          [28449., 28800.]],
 
         [[77346., 78426.],
          [81666., 82746.]]]]))

In [8]:
# Backward pass with an all-ones upstream gradient; the return value is the
# gradient with respect to x.
# NOTE(review): col2im crops its result to the spatial size of the shape that
# forward() recorded, i.e. 4x4 for this x, yet the output recorded below is
# 2x2 per channel -- it looks stale; re-run the notebook to refresh it.
conv.backward(np.ones(y.shape))

array([[[[ 27.,  56.],
         [ 60., 124.]],

        [[ 45.,  92.],
         [ 96., 196.]],

        [[ 63., 128.],
         [132., 268.]]],


       [[[ 27.,  56.],
         [ 60., 124.]],

        [[ 45.,  92.],
         [ 96., 196.]],

        [[ 63., 128.],
         [132., 268.]]]])