# Deep Learning — Homework 3, Question 14
Bahar Mahdavi - SN: 40152521337

### Let the weight of a convolution layer be $W \in \mathbb{R}^{N\times M \times L \times K}$. If we approximate this tensor by a low-rank (rank $T$) decomposition, i.e. a sum of $T$ outer products of four one-dimensional vectors $a^t \in \mathbb{R}^{N}$, $b^t \in \mathbb{R}^{M}$, $c^t \in \mathbb{R}^{L}$, $d^t \in \mathbb{R}^{K}$, then $W_{i,j,l,k} = \sum_{t=1}^{T} a^t_i \, b^t_j \, c^t_l \, d^t_k$.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

In [2]:
def VecProduct(a, b, c, d):
    """Return the outer product of four 1-D tensors as a 4-D tensor.

    Computes ``w[i, j, l, k] = a[i] * b[j] * c[l] * d[k]``, i.e. a single
    rank-1 term of the low-rank weight decomposition.

    Args:
        a, b, c, d: 1-D tensors (or array-likes convertible with
            ``torch.as_tensor``) of lengths N, M, L and K.

    Returns:
        A tensor of shape ``(N, M, L, K)``. Autograd history of the inputs
        is preserved.

    Raises:
        ValueError: if any input is not one-dimensional.
    """
    vectors = [torch.as_tensor(v) for v in (a, b, c, d)]
    for name, v in zip("abcd", vectors):
        if v.dim() != 1:
            raise ValueError(f"{name} must be 1-D, got shape {tuple(v.shape)}")
    # One einsum builds the whole 4-D outer product; no Python-level loops.
    return torch.einsum('i,j,l,k->ijlk', *vectors)

In [3]:
def weights(rank, dimN, dimM, dimL, dimK):
    """Build a random 4-D weight tensor from a rank-``rank`` decomposition.

    Four factor matrices ``a`` (rank x dimN), ``b`` (rank x dimM),
    ``c`` (rank x dimL) and ``d`` (rank x dimK) are drawn from
    ``torch.rand`` (uniform on [0, 1)) with ``requires_grad=True``. They are
    combined as::

        W[i, j, l, k] = sum_t a[t, i] * b[t, j] * c[t, l] * d[t, k]

    Args:
        rank: number of rank-1 terms summed.
        dimN, dimM, dimL, dimK: sizes of the four axes of ``W``.

    Returns:
        Tensor of shape ``(dimN, dimM, dimL, dimK)`` that carries autograd
        history back to the (freshly sampled, not returned) factors.
    """
    a = torch.rand(rank, dimN).requires_grad_(True)
    b = torch.rand(rank, dimM).requires_grad_(True)
    c = torch.rand(rank, dimL).requires_grad_(True)
    d = torch.rand(rank, dimK).requires_grad_(True)

    # Contract over the shared rank index t in one call. This replaces a
    # Python loop with in-place accumulation of `rank` separate outer
    # products.
    return torch.einsum('ti,tj,tl,tk->ijlk', a, b, c, d)

In [4]:
class ConvolutionLayer(nn.Module):
    """2-D convolution whose kernel is a learnable low-rank decomposition.

    The kernel ``W`` has shape ``dim = (N, M, L, K)``. It is passed to
    ``F.conv2d``, so N must equal ``out_channels`` and M must equal
    ``in_channels``. It is built from four factor matrices as::

        W[i, j, l, k] = sum_t a[t, i] * b[t, j] * c[t, l] * d[t, k]

    Args:
        in_channels: number of input channels (must equal ``dim[1]``).
        out_channels: number of output channels (must equal ``dim[0]``).
        rank: number of rank-1 terms in the decomposition.
        dim: 4-tuple ``(N, M, L, K)`` giving the kernel shape.

    Raises:
        ValueError: if ``dim`` does not match the channel counts.
    """

    def __init__(self, in_channels, out_channels, rank, dim):
        super(ConvolutionLayer, self).__init__()
        self.dimN, self.dimM, self.dimL, self.dimK = dim
        if self.dimN != out_channels:
            raise ValueError(f'{self.dimN} is not equal to {out_channels}')
        if self.dimM != in_channels:
            raise ValueError(f'{self.dimM} is not equal to {in_channels}')
        self.rank = rank
        # The factors are sampled once and registered as parameters. They
        # therefore persist across forward calls, show up in .parameters(),
        # and are updated by an optimizer. Same init distribution as torch.rand.
        self.a = nn.Parameter(torch.rand(rank, self.dimN))
        self.b = nn.Parameter(torch.rand(rank, self.dimM))
        self.c = nn.Parameter(torch.rand(rank, self.dimL))
        self.d = nn.Parameter(torch.rand(rank, self.dimK))

    def full_weight(self):
        """Return the dense kernel of shape (dimN, dimM, dimL, dimK)."""
        return torch.einsum('ti,tj,tl,tk->ijlk', self.a, self.b, self.c, self.d)

    def forward(self, z):
        """Convolve ``z`` with the kernel rebuilt from the current factors."""
        return F.conv2d(input=z, weight=self.full_weight())

### Test the model 

In [5]:
torch.manual_seed(0)  # fix the RNG so the tensors printed below are reproducible
z = torch.randn(1, 3, 5, 5)  # input batch, laid out as F.conv2d expects: (batch, channels, H, W)
# NOTE: W is a standalone random tensor that is only printed below; it is NOT
# the kernel used by ConvolutionLayer.
W = torch.randn(2, 3, 2, 2)

In [6]:
# Rank-2 layer: 3 input channels -> 64 output channels with a 4x4 kernel.
conv_layer = ConvolutionLayer(
    in_channels=3,
    out_channels=64,
    rank=2,
    dim=(64, 3, 4, 4),
)
C = conv_layer(z)  # output feature map

In [7]:
# Print each tensor under its label. The label for W was previously
# misspelled as "wights:".
for label, tensor in (("input:", z), ("weights:", W), ("output:", C)):
    print(label)
    print(tensor)

input:
tensor([[[[-0.3097, -1.1636,  0.6096, -0.4294,  0.3548],
          [ 1.1298,  0.2220,  0.7470,  0.1938, -1.1761],
          [-0.6657, -0.2547,  1.6188, -0.5290,  0.5796],
          [-0.3994, -0.8395,  0.2908,  0.6088,  0.6565],
          [-0.0356,  0.1512,  0.8814, -0.7965, -0.2351]],

         [[ 0.3865,  0.8598,  0.9623,  1.3150,  0.0598],
          [ 2.0278,  2.4628,  0.6752,  1.2212, -0.0390],
          [ 1.4419, -2.7861, -0.2695, -0.9984,  1.5148],
          [ 0.8156, -1.2264,  0.6983, -1.2612, -0.4677],
          [ 0.4803,  0.7049, -0.5394,  0.7509, -0.9374]],

         [[ 0.3382,  0.4085, -1.7307,  0.3687,  0.7655],
          [ 0.2558,  1.2210,  0.9032,  1.5146,  0.1163],
          [-0.1457, -0.7474, -1.0655,  0.8627,  0.6507],
          [ 1.1998,  0.2671,  1.6609, -0.2052, -0.3969],
          [-0.3965,  0.8824,  0.9324, -2.0103, -0.5503]]]])
wights:
tensor([[[[-1.3735,  2.1461],
          [-0.0203,  1.5852]],

         [[ 0.1329, -0.3977],
          [ 1.2068, -0.0409]],
