In [1]:
# im2col
# Channels: 1, batch size: 1
# No padding, stride of 1
import numpy as np

In [2]:
def im2col(image, flt_h, flt_w, out_h, out_w):
    """Unroll every filter-sized window of a 2-D image into one column.

    Windows are taken with a stride of 1 and no padding, visiting output
    positions row by row.

    Args:
        image: 2-D array (height, width).
        flt_h: Filter height.
        flt_w: Filter width.
        out_h: Number of window positions along the height.
        out_w: Number of window positions along the width.

    Returns:
        float64 array of shape (flt_h * flt_w, out_h * out_w); column
        ``top * out_w + left`` holds the flattened window whose top-left
        corner is at ``(top, left)``.
    """
    _img_h, _img_w = image.shape  # unpacking only succeeds for a 2-D input

    patch_len = flt_h * flt_w
    n_patches = out_h * out_w
    cols = np.zeros((patch_len, n_patches))  # size of the generated matrix

    # Row-major walk over window origins, so the running index equals
    # top * out_w + left.
    origins = ((top, left) for top in range(out_h) for left in range(out_w))
    for col_idx, (top, left) in enumerate(origins):
        window = image[top:top + flt_h, left:left + flt_w]
        cols[:, col_idx] = window.reshape(-1)

    return cols

In [4]:
# Demo: 4x4 image holding 1..16, 2x2 filter, 3x3 output positions.
img = np.arange(1, 17).reshape(4, 4)
cols = im2col(img, 2, 2, 3, 3)
print(cols)

[[ 1.  2.  3.  5.  6.  7.  9. 10. 11.]
 [ 2.  3.  4.  6.  7.  8. 10. 11. 12.]
 [ 5.  6.  7.  9. 10. 11. 13. 14. 15.]
 [ 6.  7.  8. 10. 11. 12. 14. 15. 16.]]


In [5]:
# Faster im2col: the loops run over filter offsets instead of output positions.
def im2col(image, flt_h, flt_w, out_h, out_w):
    """Unroll a 2-D image into columns, copying one filter offset at a time.

    For each offset (fh, fw) inside the filter, the whole out_h x out_w
    sub-image starting at that offset is copied in a single slice
    assignment (stride 1, no padding).

    Args:
        image: 2-D array (height, width).
        flt_h: Filter height.
        flt_w: Filter width.
        out_h: Output height.
        out_w: Output width.

    Returns:
        float64 array of shape (flt_h * flt_w, out_h * out_w).
    """
    _img_h, _img_w = image.shape  # unpacking only succeeds for a 2-D input

    patches = np.zeros((flt_h, flt_w, out_h, out_w))
    offsets = [(fh, fw) for fh in range(flt_h) for fw in range(flt_w)]
    for fh, fw in offsets:
        patches[fh, fw, :, :] = image[fh:fh + out_h, fw:fw + out_w]

    # Collapse (flt_h, flt_w) into rows and (out_h, out_w) into columns.
    return patches.reshape(flt_h * flt_w, out_h * out_w)

In [6]:
# im2col extended to batched, multi-channel input.
def im2col(images, flt_h, flt_w, out_h, out_w):
    """Unroll a batch of multi-channel images into a 2-D column matrix.

    Windows are taken with a stride of 1 and no padding. For each filter
    offset (h, w) the out_h x out_w sub-image starting at that offset is
    copied for every batch item and channel at once.

    Args:
        images: 4-D array (batch size, channels, image height, image width).
        flt_h: Filter height.
        flt_w: Filter width.
        out_h: Output height.
        out_w: Output width.

    Returns:
        float64 array of shape
        (n_ch * flt_h * flt_w, n_bt * out_h * out_w). Rows are ordered by
        (channel, filter row, filter column) and columns by
        (batch item, output row, output column).
    """
    # Batch size and channel count; the unpack also requires a 4-D input.
    n_bt, n_ch, _img_h, _img_w = images.shape

    # The buffer must be named `cols`: the loop and the final reshape both
    # use that name (the original allocated `col`, causing a NameError).
    cols = np.zeros((n_bt, n_ch, flt_h, flt_w, out_h, out_w))

    for h in range(flt_h):
        h_lim = h + out_h
        for w in range(flt_w):
            w_lim = w + out_w
            cols[:, :, h, w, :, :] = images[:, :, h:h_lim, w:w_lim]

    # Move the batch axis next to the output axes, then flatten to 2-D.
    cols = cols.transpose(1, 2, 3, 0, 4, 5).reshape(
        n_ch * flt_h * flt_w, n_bt * out_h * out_w)
    return cols

In [7]:
# im2col with padding and stride taken into account.
def im2col(images, flt_h, flt_w, out_h, out_w, stride, pad):
    """Unroll padded, strided windows of a batch of images into columns.

    The two trailing axes of ``images`` are padded with ``pad`` constant
    (zero) entries on each side. For every filter offset (fh, fw), the
    padded input is sampled every ``stride`` entries to fill an
    out_h x out_w grid.

    Args:
        images: 4-D array (batch size, channels, image height, image width).
        flt_h: Filter height.
        flt_w: Filter width.
        out_h: Output height.
        out_w: Output width.
        stride: Step between neighbouring windows.
        pad: Padding width added on every side of the two trailing axes.

    Returns:
        float64 array of shape
        (n_ch * flt_h * flt_w, n_bt * out_h * out_w).
    """
    n_bt, n_ch, _, _ = images.shape

    edge = (pad, pad)
    padded = np.pad(images, [(0, 0), (0, 0), edge, edge], "constant")
    patches = np.zeros((n_bt, n_ch, flt_h, flt_w, out_h, out_w))

    span_h = stride * out_h
    span_w = stride * out_w
    offsets = [(fh, fw) for fh in range(flt_h) for fw in range(flt_w)]
    for fh, fw in offsets:
        rows = slice(fh, fh + span_h, stride)
        columns = slice(fw, fw + span_w, stride)
        patches[:, :, fh, fw, :, :] = padded[:, :, rows, columns]

    # Bring the batch axis next to the output axes before flattening.
    n_rows = n_ch * flt_h * flt_w
    n_cols = n_bt * out_h * out_w
    return patches.transpose(1, 2, 3, 0, 4, 5).reshape(n_rows, n_cols)

In [8]:
# Demo: one single-channel 4x4 image (values 1..16), 2x2 filter,
# 3x3 output, stride 1, no padding.
img = np.arange(1, 17).reshape(1, 1, 4, 4)
cols = im2col(img, 2, 2, 3, 3, stride=1, pad=0)
print(cols)

[[ 1.  2.  3.  5.  6.  7.  9. 10. 11.]
 [ 2.  3.  4.  6.  7.  8. 10. 11. 12.]
 [ 5.  6.  7.  9. 10. 11. 13. 14. 15.]
 [ 6.  7.  8. 10. 11. 12. 14. 15. 16.]]


In [9]:
# col2im implementation.
def col2im(cols, img_shape, flt_h, flt_w, out_h, out_w, stride, pad):
    """Scatter a column matrix back into image form, summing overlaps.

    ``cols`` is interpreted with rows ordered by (channel, filter row,
    filter column) and columns ordered by (batch item, output row, output
    column). Every entry is added back at the position of the padded image
    it corresponds to, and the padding border is cropped before returning.

    Args:
        cols: Array reshapeable to
            (n_ch, flt_h, flt_w, n_bt, out_h, out_w).
        img_shape: Target shape (batch size, channels, height, width).
        flt_h: Filter height.
        flt_w: Filter width.
        out_h: Output height.
        out_w: Output width.
        stride: Step between neighbouring windows.
        pad: Padding width on every side of the two trailing axes.

    Returns:
        float64 array of shape ``img_shape``; positions covered by several
        windows hold the sum of their contributions.
    """
    n_bt, n_ch, img_h, img_w = img_shape

    patches = cols.reshape(n_ch, flt_h, flt_w, n_bt, out_h, out_w)
    patches = patches.transpose(3, 0, 1, 2, 4, 5)

    # Padded canvas, with stride - 1 extra rows and columns.
    canvas_h = img_h + 2 * pad + stride - 1
    canvas_w = img_w + 2 * pad + stride - 1
    canvas = np.zeros((n_bt, n_ch, canvas_h, canvas_w))

    span_h = stride * out_h
    span_w = stride * out_w
    offsets = [(fh, fw) for fh in range(flt_h) for fw in range(flt_w)]
    for fh, fw in offsets:
        rows = slice(fh, fh + span_h, stride)
        columns = slice(fw, fw + span_w, stride)
        canvas[:, :, rows, columns] += patches[:, :, fh, fw, :, :]

    # Crop the padding border.
    return canvas[:, :, pad:img_h + pad, pad:img_w + pad]

In [10]:
# Demo: all-ones columns for a 2x2 filter over a 3x3 image (2x2 output),
# stride 1, no padding.
cols = np.ones((4, 4))
img_shape = (1, 1, 3, 3)
images = col2im(cols, img_shape, 2, 2, 2, 2, stride=1, pad=0)
print(images)

[[[[1. 2. 1.]
   [2. 4. 2.]
   [1. 2. 1.]]]]
