In [307]:
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable

In [308]:
class Conv1D():
    """1-D convolution layer (no padding, no dilation) implemented with NumPy.

    Parameters
    ----------
    in_channel : int
        Number of channels of the input.
    out_channel : int
        Number of channels produced by the layer.
    kernel_size : int
        Width of the convolution window.
    stride : int, default 1
        Step between consecutive windows; must be at least 1.
    W_init : array-like of shape (out_channel, in_channel, kernel_size), optional
        Initial weights. When omitted, weights are drawn from ``np.random.rand``.
    b_init : array-like of shape (out_channel,), optional
        Initial bias. When omitted, the bias starts at zero.

    Attributes
    ----------
    W, b : weights and bias.
    dW, db : gradients w.r.t. ``W`` and ``b`` from the most recent ``backward`` call.
    X, l : input cached by the most recent ``forward`` call and its length.
    """

    def __init__(self,in_channel,out_channel,kernel_size,stride=1, W_init=None, b_init=None):
        if stride < 1:
            # A non-positive stride would never advance the window in forward().
            raise ValueError("stride must be >= 1, got {}".format(stride))

        self.in_channel = in_channel
        self.out_channel = out_channel
        self.kernel_size = kernel_size
        self.stride = stride

        # Compare against None explicitly: the truth value of a multi-element
        # array is ambiguous (raises ValueError), and an all-zero initialiser
        # would otherwise be treated as "not provided".
        if W_init is not None:
            self.W = np.asarray(W_init)
        else:
            self.W = np.random.rand(out_channel, in_channel, kernel_size)
        if b_init is not None:
            self.b = np.asarray(b_init)
        else:
            self.b = np.zeros(out_channel)

        self.l = 0
        self.X = None
        self.dW, self.db = np.zeros(self.W.shape), np.zeros(self.b.shape)

    def __call__(self,x):
        """Alias for :meth:`forward`."""
        return self.forward(x)

    def forward(self,x):
        """Apply the convolution.

        Parameters
        ----------
        x : ndarray of shape (batch, in_channel, length)

        Returns
        -------
        ndarray of shape (batch, out_channel, out_length) where
        ``out_length = (length - kernel_size) // stride + 1``.

        Raises
        ------
        ValueError
            If ``length`` is smaller than ``kernel_size`` (no window fits).
        """
        length = x.shape[2]
        if length < self.kernel_size:
            raise ValueError(
                "input length {} is shorter than kernel_size {}".format(length, self.kernel_size)
            )

        # Cache what backward() needs.
        self.l = length
        self.X = x

        out_length = (length - self.kernel_size) // self.stride + 1
        # Preallocate instead of growing the result with repeated concatenation.
        out = np.empty(
            (x.shape[0], self.W.shape[0], out_length),
            dtype=np.result_type(x, self.W, self.b),
        )
        for i in range(out_length):
            start = i * self.stride
            segment = x[:, :, start:start + self.kernel_size]
            # Contract (in_channel, kernel) of the window against the weights;
            # the bias broadcasts over the batch axis.
            out[:, :, i] = np.tensordot(segment, self.W, axes=([1, 2], [1, 2])) + self.b

        return out

    def backward(self,delta):
        """Back-propagate the gradient of the loss w.r.t. the layer output.

        Stores the parameter gradients of this call in ``self.dW`` and
        ``self.db`` (both are overwritten, not accumulated).

        Parameters
        ----------
        delta : ndarray of shape (batch, out_channel, out_length)
            Gradient w.r.t. the output of the preceding ``forward`` call.

        Returns
        -------
        ndarray of shape (batch, in_channel, length)
            Gradient w.r.t. the input of the preceding ``forward`` call.

        Raises
        ------
        RuntimeError
            If called before ``forward``.
        """
        if self.X is None:
            raise RuntimeError("backward() called before forward()")

        batch = delta.shape[0]
        dx = np.zeros((batch, self.in_channel, self.l))
        dW = np.zeros(self.W.shape)

        for i in range(delta.shape[2]):
            start = i * self.stride
            end = start + self.kernel_size
            # The bias is an additive constant in forward(), so it does not
            # appear in the gradients w.r.t. the input or the weights.
            grad_i = delta[:, :, i]
            dx[:, :, start:end] += np.tensordot(grad_i, self.W, axes=([1], [0]))
            dW += np.tensordot(grad_i, self.X[:, :, start:end], axes=([0], [0]))

        self.dW = dW
        self.db = np.sum(delta, axis=(0, 2))

        return dx
        

In [309]:
def compare(x,y):
    """Print the largest absolute element-wise difference between ``x`` and ``y``.

    ``x`` is used as given; ``y`` is detached from the autograd graph and
    converted to a NumPy array before the subtraction. Returns ``None``.
    """
    reference = y.detach().numpy()
    gap = abs(x - reference)
    print(gap.max())

In [310]:

# Build the NumPy layer under test and the PyTorch reference layer with the
# same hyper-parameters (8 -> 12 channels, kernel width 3, stride 2).
net1 = Conv1D(in_channel=8, out_channel=12, kernel_size=3, stride=2)
net2 = nn.Conv1d(in_channels=8, out_channels=12, kernel_size=3, stride=2)

In [311]:
# Random input of shape (10, 8, 20); the middle axis matches the 8 input
# channels of the layers.
x1 = np.random.rand(10, 8, 20)
# torch.tensor copies the array (keeping float64); requires_grad=True makes it
# a leaf tensor so x2.grad is populated by backward(). This replaces the
# deprecated torch.autograd.Variable wrapper with the equivalent direct call.
x2 = torch.tensor(x1, requires_grad=True)

In [312]:
# Replace the PyTorch layer's parameters with copies of the NumPy layer's
# weights and bias so both layers compute with identical values.
net2.weight = nn.Parameter(torch.tensor(net1.W))
net2.bias = nn.Parameter(torch.tensor(net1.b))

In [313]:
# Forward pass through the NumPy layer; the three output dimensions are
# unpacked so a matching upstream gradient can be generated afterwards.
y1 = net1(x1)
b, c, w = y1.shape

In [314]:
# Random upstream gradient with the same shape as the forward output, pushed
# back through the NumPy layer; dx is the gradient backward() returns.
delta = np.random.randn(b,c,w)
dx = net1.backward(delta)

In [315]:
# Forward and backward through the PyTorch layer with the same upstream
# gradient, which populates x2.grad, net2.weight.grad and net2.bias.grad.
y2 = net2(x2)
# NOTE: this rebinds `delta` from a NumPy array to a torch tensor.
delta = torch.tensor(delta)
y2.backward(delta)

In [316]:
# Print the largest absolute difference between the NumPy and PyTorch results,
# in order: forward output, input gradient, weight gradient, bias gradient.
compare(y1, y2)
compare(dx, x2.grad)
compare(net1.dW, net2.weight.grad)
compare(net1.db, net2.bias.grad)

3.552713678800501e-15
8.881784197001252e-16
3.552713678800501e-15
3.552713678800501e-15
