# Exercise 2

1. Install `scipy` if it is not already installed.
2. Implement `im2col_matrix_sparse` using a SciPy sparse matrix (`scipy.sparse.csr_matrix`).

In [None]:
import numpy as np
from scipy.sparse import csr_matrix
import time

In [None]:


#     ###### from the slides:
#     # data = [1,1,1,1,1,...] # len(data) = P * patch_size
#     # row_indices = [0,1,3,4,9,...]
#     # col_indices = [0,1,2,3,4,...]
#     # im2col_mat_sparse = csr_matrix((data,
#     #                                 (row_indices, col_indices)),
#     #                                shape=(n_rows, n_cols))


# # im2col
# def im2col_matrix_sparse(Xin, K, S=1):
#     N, Cin, Hin, Win = Xin.shape
#     CHW = Cin * Hin * Win
#     Hout = (Hin - K)//S + 1
#     Wout = (Win - K)//S + 1
#     P = Hout * Wout  # Total number of patches per image
#     patch_size = Cin * K * K # Size of each flattened patch

#     # YOUR CODE HERE

#     # n_rows = ?
#     # n_cols = ?
#     # data = ?
#     # row_indices = ?
#     # col_indices = ?

#     im2col_mat_sparse = csr_matrix((data, (row_indices, col_indices)), shape=(n_rows, n_cols))
#     return im2col_mat_sparse





## SOLUTION

In [None]:
# im2col
def im2col_matrix_sparse(Xin, K, S=1):
    """Build the sparse 0/1 selection matrix that performs im2col via matmul.

    Right-multiplying a flattened batch of shape ``(N, Cin*Hin*Win)`` by the
    returned matrix gives an array of shape ``(N, P*patch_size)`` holding
    every ``K x K`` patch. Columns are ordered patch-major (``hout`` then
    ``wout``) and, inside a patch, by ``(cin, hker, wker)``.

    Parameters
    ----------
    Xin : object with a 4-element ``shape``
        Input batch laid out as ``(N, Cin, Hin, Win)``. Only the shape is read.
    K : int
        Side length of the square kernel.
    S : int, default 1
        Stride applied along both spatial axes.

    Returns
    -------
    scipy.sparse.csr_matrix
        Matrix of shape ``(Cin*Hin*Win, P*patch_size)`` with exactly one
        entry equal to 1 in every column.

    Raises
    ------
    ValueError
        If ``K < 1``, ``S < 1``, or the kernel does not fit in the input.
    """
    _, Cin, Hin, Win = Xin.shape
    if K < 1 or S < 1:
        raise ValueError(f"K and S must be >= 1, got K={K}, S={S}")
    if K > Hin or K > Win:
        # Without this guard Hout/Wout become non-positive and the index
        # arrays below no longer match P * patch_size.
        raise ValueError(
            f"kernel size K={K} does not fit input of size {Hin}x{Win}"
        )

    CHW = Cin * Hin * Win
    Hout = (Hin - K)//S + 1
    Wout = (Win - K)//S + 1
    P = Hout * Wout  # Total number of patches per image
    patch_size = Cin * K * K  # Size of each flattened patch
    n_cols = P * patch_size

    # One axis per loop level of the naive implementation; broadcasting gives
    # an array of shape (Hout, Wout, Cin, K, K) whose C-order ravel visits the
    # entries in exactly the nested-loop order hout, wout, cin, hker, wker.
    hout = np.arange(Hout).reshape(-1, 1, 1, 1, 1)
    wout = np.arange(Wout).reshape(1, -1, 1, 1, 1)
    cin = np.arange(Cin).reshape(1, 1, -1, 1, 1)
    hker = np.arange(K).reshape(1, 1, 1, -1, 1)
    wker = np.arange(K).reshape(1, 1, 1, 1, -1)

    # Flat offset of pixel (cin, hout*S + hker, wout*S + wker) in one image.
    input_index = cin * Hin * Win + (hout * S + hker) * Win + wout * S + wker

    row_indices = input_index.ravel()
    col_indices = np.arange(n_cols)  # each output slot gets its own column
    data = np.ones(n_cols, dtype=int)

    im2col_mat_sparse = csr_matrix((data, (row_indices, col_indices)), shape=(CHW, n_cols))
    return im2col_mat_sparse



# Paste your `im2col_matrix_dense` here:

In [None]:
# DELETE
# im2col for dense matrix
def im2col_matrix_dense(Xin, K, S=1):
    """Build the dense 0/1 selection matrix that performs im2col via matmul.

    Right-multiplying a flattened batch of shape ``(N, Cin*Hin*Win)`` by the
    returned matrix gives an array of shape ``(N, P*patch_size)`` holding
    every ``K x K`` patch. Columns are ordered patch-major (``hout`` then
    ``wout``) and, inside a patch, by ``(cin, hker, wker)``.

    Parameters
    ----------
    Xin : object with a 4-element ``shape``
        Input batch laid out as ``(N, Cin, Hin, Win)``. Only the shape is read.
    K : int
        Side length of the square kernel.
    S : int, default 1
        Stride applied along both spatial axes.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(Cin*Hin*Win, P*patch_size)`` with exactly one
        entry equal to 1 in every column and zeros elsewhere.

    Raises
    ------
    ValueError
        If ``K < 1``, ``S < 1``, or the kernel does not fit in the input.
    """
    _, Cin, Hin, Win = Xin.shape
    if K < 1 or S < 1:
        raise ValueError(f"K and S must be >= 1, got K={K}, S={S}")
    if K > Hin or K > Win:
        # Otherwise Hout/Wout are non-positive; when both are negative their
        # product is positive and an all-zero matrix would be returned.
        raise ValueError(
            f"kernel size K={K} does not fit input of size {Hin}x{Win}"
        )

    Hout, Wout = (Hin - K)//S + 1, (Win - K)//S + 1
    P = Hout * Wout  # Total number of patches per image
    patch_size = Cin * K * K  # Size of each flattened patch
    n_cols = P * patch_size
    im2col_mat_dense = np.zeros((Cin * Hin * Win, n_cols))

    # One axis per loop level of the naive implementation; broadcasting gives
    # an array of shape (Hout, Wout, Cin, K, K) whose C-order ravel visits the
    # entries in exactly the nested-loop order hout, wout, cin, hker, wker.
    hout = np.arange(Hout).reshape(-1, 1, 1, 1, 1)
    wout = np.arange(Wout).reshape(1, -1, 1, 1, 1)
    cin = np.arange(Cin).reshape(1, 1, -1, 1, 1)
    hker = np.arange(K).reshape(1, 1, 1, -1, 1)
    wker = np.arange(K).reshape(1, 1, 1, 1, -1)

    # Flat offset of pixel (cin, hout*S + hker, wout*S + wker) in one image.
    input_index = cin * Hin * Win + (hout * S + hker) * Win + wout * S + wker

    # The k-th visited pixel is routed to output column k.
    im2col_mat_dense[input_index.ravel(), np.arange(n_cols)] = 1
    return im2col_mat_dense


# Defining a sparse matrix

In [None]:
# Each non-zero entry written as one (row, column, value) triplet; the three
# parallel lists that csr_matrix expects are then unzipped from the triplets.
entries = [
    (1, 0, 1),
    (2, 1, 2),
    (0, 2, 3),
    (1, 1, 1),
]
row_indices, col_indices, data = (list(axis) for axis in zip(*entries))

# data[k] is stored at position (row_indices[k], col_indices[k]) of a 3x3 matrix.
coordinates = (row_indices, col_indices)
sparse_mat_example = csr_matrix((data, coordinates), shape=(3, 3))
sparse_mat_example.toarray()

array([[0, 0, 3],
       [1, 1, 0],
       [0, 2, 0]])

# Matrix multiplication

In [None]:
# Dense 3x3 operand on the left, the sparse example matrix on the right of `@`.
X = np.arange(9).reshape((3, 3))
X @ sparse_mat_example


array([[ 1,  5,  0],
       [ 4, 14,  9],
       [ 7, 23, 18]])

In [None]:
# Synthetic batch shaped (N, Cin, Hin, Win) = (80, 1, 32, 32); every pixel
# holds a distinct integer so extracted patches are easy to recognise.
Xin = np.arange(80*1*32*32).reshape(80,1,32,32)

N, Cin, Hin, Win = Xin.shape
K = 5
S = 1

# Same output-size formula as the im2col builders, so P stays consistent with
# the matrices they return if S is changed.
Hout, Wout = (Hin - K)//S + 1, (Win - K)//S + 1
P = Hout * Wout
patch_size = Cin * K * K

# One row per image, matching the row dimension of the im2col matrices.
Xin_flat = Xin.reshape(-1, Cin * Hin * Win)

# Build both selection matrices up front so that only the matmul is timed.
im2col_mat_dense = im2col_matrix_dense(Xin, K, S)
im2col_mat_sparse = im2col_matrix_sparse(Xin, K, S)

# time.perf_counter() is the high-resolution clock intended for measuring
# short durations; time.time() follows the wall clock and may be adjusted.

# Dense method
start_dense = time.perf_counter()
Xin_im2col_dense = Xin_flat @ im2col_mat_dense
end_dense = time.perf_counter()
print(f"Dense matmul took {end_dense - start_dense:.6f} seconds, excluding im2col creation.")

# Sparse method
start_sparse = time.perf_counter()
Xin_im2col_sparse = Xin_flat @ im2col_mat_sparse
end_sparse = time.perf_counter()
print(f"Sparse matmul took {end_sparse - start_sparse:.6f} seconds, excluding im2col creation.")



Dense matmul took 0.374793 seconds, excluding im2col creation.
Sparse matmul took 0.013100 seconds, excluding im2col creation.


# Sanity check: the two methods agree

In [None]:
# Unflatten each result: (N, P*patch_size) -> (N, P, patch_size) -> (N, P, Cin, K, K).
Xin_patches_flat_dense = Xin_im2col_dense.reshape(N, P, patch_size)
Xin_patches_dense = Xin_patches_flat_dense.reshape(N, P, Cin, K, K)

print("Dense method:  Showing the first two patches of X[0]:")
print(Xin_patches_dense[0][:2])


Xin_patches_flat_sparse = Xin_im2col_sparse.reshape(N, P, patch_size)
Xin_patches_sparse = Xin_patches_flat_sparse.reshape(N, P, Cin, K, K)

print("Sparse method: Showing the first two patches of X[0]:")
print(Xin_patches_sparse[0][:2])

# Enforce the agreement instead of relying on a visual comparison of the two
# printouts; this fails loudly if any patch value differs between the methods.
np.testing.assert_array_equal(Xin_patches_dense, Xin_patches_sparse)


Dense method:  Showing the first two patches of X[0]:
[[[[  0.   1.   2.   3.   4.]
   [ 32.  33.  34.  35.  36.]
   [ 64.  65.  66.  67.  68.]
   [ 96.  97.  98.  99. 100.]
   [128. 129. 130. 131. 132.]]]


 [[[  1.   2.   3.   4.   5.]
   [ 33.  34.  35.  36.  37.]
   [ 65.  66.  67.  68.  69.]
   [ 97.  98.  99. 100. 101.]
   [129. 130. 131. 132. 133.]]]]
Sparse method: Showing the first two patches of X[0]:
[[[[  0   1   2   3   4]
   [ 32  33  34  35  36]
   [ 64  65  66  67  68]
   [ 96  97  98  99 100]
   [128 129 130 131 132]]]


 [[[  1   2   3   4   5]
   [ 33  34  35  36  37]
   [ 65  66  67  68  69]
   [ 97  98  99 100 101]
   [129 130 131 132 133]]]]
