In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

import torchvision
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import save_image
from torchsummary import summary

from random import randint

from polynomial_nets import CP_L3, CP_L3_sparse

from poly_VAE import Flatten, UnFlatten, VAE_CP_L3, VAE_CP_L3_sparse, VAE_CP_L3_sparse_LU, loss_fn
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torch.utils.data import random_split

import matplotlib.pyplot as plt
from torchvision.utils import make_grid
import pandas as pd

In [3]:
BATCH_SIZE = 64

In [4]:
# Download training dataset
dataset = MNIST(root='data/', download=True)
# MNIST dataset (images and labels)
dataset = MNIST(root='data/', train=True, transform=transforms.ToTensor())

In [5]:
train_ds, val_ds = random_split(dataset, [50000, 10000])


In [6]:
train_loader = DataLoader(train_ds, BATCH_SIZE, shuffle=True)
#train_loader = DeviceDataLoader(train_loader, device)
#val_loader = DataLoader(val_ds, BATCH_SIZE)

In [7]:
for images,_ in train_loader:
    print(images[0].shape)
    break

torch.Size([1, 28, 28])


In [8]:
test_image = images[0]

In [9]:
test_tensor = torch.randn((1, 28, 28))

In [10]:
kernel_size = 3
stride = 1


In [11]:
test = torch.tensor([[1,2,3], [4,5,6], [7,8,9]])

In [12]:
test_reshape = test.reshape((9))

In [13]:
test_range = torch.arange(0,27).reshape((3,3,3))

https://leimao.github.io/blog/Convolution-Transposed-Convolution-As-Matrix-Multiplication/

In [14]:
def corr2d(X, K):

    # Convolution in deep learning is a misnomer.
    # In fact, it is cross-correlation.
    # https://d2l.ai/chapter_convolutional-neural-networks/conv-layer.html
    # This is equivalent as Conv2D that that input_channel == output_channel == 1 and stride == 1.

    assert X.dim() == 2 and K.dim() == 2

    h, w = K.shape
    Y = torch.zeros((X.shape[0] - h + 1, X.shape[1] - w + 1))
    for i in range(Y.shape[0]):
        for j in range(Y.shape[1]):
            Y[i, j] = (X[i:i + h, j:j + w] * K).sum()

    return Y

In [15]:
def conv2d_as_matrix_mul(X, K):

    # Assuming no channels and stride == 1.
    # Convert the kernel matrix to sparse matrix (dense matrix with lots of zeros in fact).
    # This is a little bit brain-twisting.

    h_K, w_K = K.shape
    h_X, w_X = X.shape

    h_Y, w_Y = h_X - h_K + 1, w_X - w_K + 1

    W = get_sparse_kernel_matrix(K=K, h_X=h_X, w_X=w_X)

    Y = torch.matmul(W, X.reshape(-1)).reshape(h_Y, w_Y)

    return Y

In [16]:
def conv_transposed_2d_as_matrix_mul(X, K):

    # Assuming no channels and stride == 1.
    # Convert the kernel matrix to sparse matrix (dense matrix with lots of zeros in fact).
    # This is a little bit brain-twisting.

    h_K, w_K = K.shape
    h_X, w_X = X.shape

    h_Y, w_Y = h_X + h_K - 1, w_X + w_K - 1

    # It's like the kernel were applied on the output tensor.
    W = get_sparse_kernel_matrix(K=K, h_X=h_Y, w_X=w_Y)

    # Weight matrix tranposed.
    Y = torch.matmul(W.T, X.reshape(-1)).reshape(h_Y, w_Y)

    return Y

In [17]:
def get_sparse_kernel_matrix(K, h_X, w_X):

    # Assuming no channels and stride == 1.
    # Convert the kernel matrix to sparse matrix (dense matrix with lots of zeros in fact).
    # This is a little bit brain-twisting.

    h_K, w_K = K.shape

    h_Y, w_Y = h_X - h_K + 1, w_X - w_K + 1

    W = torch.zeros((h_Y * w_Y, h_X * w_X))
    for i in range(h_Y):
        for j in range(w_Y):
            for ii in range(h_K):
                for jj in range(w_K):
                    W[i * w_Y + j, i * w_X + j + ii * w_X + jj] = K[ii, jj]

    return W

In [18]:
X = torch.arange(28*28).reshape(28, 28).float()
K = torch.arange(9).reshape(3, 3).float()


In [19]:
h_K, w_K = K.shape 
h_X, w_X = X.shape

In [20]:
W = get_sparse_kernel_matrix(K, h_X, w_X)

In [21]:
Y = torch.matmul(W, X.reshape(-1))

In [22]:
X_1 = torch.matmul(W.T, Y)

In [23]:
X_r = X.reshape(-1)

In [24]:
W_shape = W.shape

In [25]:
test_index = torch.tensor([[0, 1], [1, 0]])

In [26]:
target = torch.zeros([5,3])
indices = torch.LongTensor([[0,1], [1, 2], [2, 2], [3, 0], [4, 1]])
value = torch.ones(indices.shape[0])
target.index_put_(tuple(indices.t()), value)

tensor([[0., 1., 0.],
        [0., 0., 1.],
        [0., 0., 1.],
        [1., 0., 0.],
        [0., 1., 0.]])

In [27]:
indices.t()

tensor([[0, 1, 2, 3, 4],
        [1, 2, 2, 0, 1]])

In [28]:
zero_matrix = torch.zeros(W_shape)

In [29]:
indices = torch.nonzero(W)

In [30]:
matrix = torch.randn(W_shape)

In [31]:
#res = matrix[indices[:,[0]], indices[:,[1]]]
#res = matrix[test_index]
#res = W[indices[:,[0]], indices[:,[1]]]
res1 = W[indices[:,0], indices[:,1]]

In [32]:
zero_matrix.index_put_(tuple(indices.t()), res1)

tensor([[0., 1., 2.,  ..., 0., 0., 0.],
        [0., 0., 1.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 8., 0., 0.],
        [0., 0., 0.,  ..., 7., 8., 0.],
        [0., 0., 0.,  ..., 6., 7., 8.]])

In [33]:
torch.max(W -  zero_matrix)

tensor(0.)

In [34]:
kernel_width = 4
kernel_height = kernel_width
image_width = 16
image_height = image_width
stride = 2
width_steps = int((image_width - kernel_width)/stride)
height_steps = int((image_height - kernel_height)/stride)


In [35]:
base = torch.arange(0, kernel_width)
base_new = base.repeat(kernel_height)
addition = torch.arange(0, kernel_width) * image_width
addition_new  = addition.repeat_interleave(kernel_height)
index_1 = addition_new + base_new

In [36]:
index_width = index_1.repeat(width_steps + 1)
addition_width = 2 * torch.arange(0, width_steps + 1)
addition_width = addition_width.repeat_interleave(kernel_height * kernel_width)
index_row = index_width + addition_width

In [37]:
index_column = index_row.repeat(height_steps + 1)
addition_height = 2 * image_width * torch.arange(0, height_steps + 1)
addition_height_rep = addition_height.repeat_interleave(kernel_height * kernel_width * (height_steps + 1))
index_final = index_column + addition_height_rep

In [38]:
stack = torch.arange(0, (height_steps + 1) * (width_steps + 1)).repeat_interleave(kernel_height * kernel_width)
indices = torch.stack((stack, index_final), dim=1)

In [39]:
values = torch.randn(indices.shape[0])
values1 = torch.nn.Linear(indices.shape[0], 1).weight.to(torch.float64)[0]

In [40]:
W_in = image_height * image_width
W_out = (height_steps + 1) * (width_steps + 1)
W = torch.zeros([W_out, W_in], dtype=torch.float64)
W = W.index_put_(tuple(indices.t()), values1)

W = torch.zeros([W_out, W_in], dtype=torch.float32)
values = torch.nn.Linear(indices.shape[0], 1).weight.to(torch.float32)[0]
W = W.index_put_(tuple(indices.t()), values)

mask = torch.zeros([W_out, W_in], dtype=torch.float64)
values_m = torch.ones_like(values)
mask = mask.index_put_(tuple(indices.t()), values_m)

RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.

In [41]:
W.dtype
values1.dtype

torch.float64

In [42]:
kernel_height * kernel_width * (height_steps + 1) * (width_steps + 1)

784

In [43]:
def weight_matrix_k4s2(image_size):
    
    kernel_width = 4
    kernel_height = kernel_width
    image_width = image_size
    image_height = image_width
    stride = 2
    width_steps = int((image_width - kernel_width)/stride)
    height_steps = int((image_height - kernel_height)/stride)

    base = torch.arange(0, kernel_width)
    base_new = base.repeat(kernel_height)
    addition = torch.arange(0, kernel_width) * image_width
    addition_new  = addition.repeat_interleave(kernel_height)
    index_1 = addition_new + base_new

    index_width = index_1.repeat(width_steps + 1)
    addition_width = stride * torch.arange(0, width_steps + 1)
    addition_width = addition_width.repeat_interleave(kernel_height * kernel_width)
    index_row = index_width + addition_width

    index_column = index_row.repeat(height_steps + 1)
    addition_height = stride * image_width * torch.arange(0, height_steps + 1)
    addition_height_rep = addition_height.repeat_interleave(kernel_height * kernel_width * (height_steps + 1))
    index_final = index_column + addition_height_rep

    stack = torch.arange(0, (height_steps + 1) * (width_steps + 1)).repeat_interleave(kernel_height * kernel_width)
    indices = torch.stack((stack, index_final), dim=1)

    W_in = image_height * image_width
    W_out = (height_steps + 1) * (width_steps + 1)
    W = torch.zeros([W_out, W_in], dtype=torch.float64)
    values = torch.nn.Linear(indices.shape[0], 1).weight.to(torch.float64)[0]
    W = W.index_put_(tuple(indices.t()), values)

    return W

In [49]:
W1 = weight_matrix_k4s2(32)
W2 = weight_matrix_k4s2(32)
W3 = torch.vstack((W1, W2))

In [53]:
W_tup = [W1]*5

In [55]:
[weight_matrix_k4s2(d)[0]]*6

tensor([[-0.0013, -0.0071,  0.0122,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000, -0.0031,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ..., -0.0159,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ..., -0.0165, -0.0135, -0.0124]],
       dtype=torch.float64, grad_fn=<IndexPutBackward0>)

In [46]:
W = weight_matrix_k4s2(32).T

In [47]:
W.shape[1]

225

In [48]:
torch.sqrt(torch.tensor(32*32))

tensor(32.)