In [None]:
import torch
from torch.testing._internal.common_utils import TestCase
from torch.testing._internal.common_utils import dtype2prec_DONTUSE

from depthwise_conv3d import DepthwiseConv3d


class TestConv(TestCase):
    def test_Conv3d_depthwise_naive_groups_cuda(self, dtype=torch.float):
        for depth_multiplier in [1, 2]:
            m = DepthwiseConv3d(2, 2 * depth_multiplier, kernel_size=3, groups=2).to("cuda", dtype)
            i = torch.randn(2, 2, 6, 6, 6, device="cuda", dtype=dtype).div_(2).requires_grad_()
            output = m(i)
            grad_output = torch.randn(2, 2 * depth_multiplier, 4, 4, 4, device="cuda", dtype=dtype) / 2
            output.backward(grad_output)

            offset = 1 * depth_multiplier

            m1 = DepthwiseConv3d(1, 1 * depth_multiplier, kernel_size=3).to("cuda", dtype)
            m1.weight.data = m.weight.data[:offset].clone()
            m1.bias.data = m.bias.data[:offset].clone()
            i1 = i.detach()[:, :1].clone().requires_grad_()
            output1 = m1(i1)
            output1.backward(grad_output[:, :offset].contiguous())

            m2 = DepthwiseConv3d(1, 1 * depth_multiplier, kernel_size=3).to("cuda", dtype)
            m2.weight.data.copy_(m.weight.data[offset:])
            m2.bias.data.copy_(m.bias.data[offset:])
            i2 = i.detach()[:, 1:].clone().requires_grad_()
            output2 = m2(i2)
            output2.backward(grad_output[:, offset:].contiguous())

            self.assertEqual(output, torch.cat([output1, output2], 1),
                             atol=dtype2prec_DONTUSE[dtype], rtol=0)
            self.assertEqual(i.grad.data,
                             torch.cat([i1.grad.data, i2.grad.data], 1),
                             atol=dtype2prec_DONTUSE[dtype], rtol=0)
            self.assertEqual(m.bias.grad.data,
                             torch.cat([m1.bias.grad.data,
                                        m2.bias.grad.data], 0),
                             atol=dtype2prec_DONTUSE[dtype], rtol=0)
            self.assertEqual(m.weight.grad.data,
                             torch.cat([m1.weight.grad.data,
                                        m2.weight.grad.data], 0),
                             atol=dtype2prec_DONTUSE[dtype], rtol=0)

In [None]:
test = TestConv()

In [None]:
test.grad_check()

In [1]:
import torch
from depthwise_conv3d import DepthwiseConv3d

dtype = torch.float
# conv = DepthwiseConv3d(64, 64, kernel_size=7, groups=1, padding = 3).to("cuda", dtype)
# input = torch.randn(2, 64, 128, 128, 128, device="cuda", dtype=dtype).div_(2).requires_grad_()
# output = conv(input)
# print(output.shape)

In [5]:
from torch import nn
import math

class diff_net(nn.Module):
    def __init__(self, input_channels):
        super(diff_net, self).__init__()
#         self.A = nn.Parameter(torch.rand((1,1,6,6,6)))
        self.weight = nn.Parameter(torch.Tensor(1,input_channels,128,128,128).cuda())
        self.diff = DepthwiseConv3d(input_channels, input_channels, kernel_size=3, groups=1, padding = 1).to("cuda", dtype)
        self.reset_parameters()
    def reset_parameters(self):
        n = 1
        for k in (3,3,3):
            n *= k
        stdv = 1. / math.sqrt(n)
        self.weight.data.uniform_(-stdv, stdv)

    def forward(self, x):
        
        x = x*(1/self.weight)
        
        x = self.diff(x)
        #print(x)
        return x

In [2]:
from torch import nn

input = torch.ones(1, 1, 6, 6, 6, device="cuda", dtype=dtype).requires_grad_()
input = input.type(torch.cuda.FloatTensor)
diff = DepthwiseConv3d(1, 1, kernel_size=5, groups=1, padding = 2).to("cuda", dtype)
y = diff(input)
loss = (y-y+1).sum()
loss.backward()
print(y)

(5, 5, 5)
torch.Size([5, 5, 5])
torch.Size([1, 1, 5, 5, 5])
torch.Size([1, 1, 5, 5, 5])
torch.Size([1, 1, 5, 5, 5])
tensor([[[[[1., 1., 1., 1., 1., 1.],
           [1., 1., 1., 1., 1., 1.],
           [1., 1., 1., 1., 1., 1.],
           [1., 1., 1., 1., 1., 1.],
           [1., 1., 1., 1., 1., 1.],
           [1., 1., 1., 1., 1., 1.]],

          [[1., 1., 1., 1., 1., 1.],
           [1., 1., 1., 1., 1., 1.],
           [1., 1., 1., 1., 1., 1.],
           [1., 1., 1., 1., 1., 1.],
           [1., 1., 1., 1., 1., 1.],
           [1., 1., 1., 1., 1., 1.]],

          [[1., 1., 1., 1., 1., 1.],
           [1., 1., 1., 1., 1., 1.],
           [1., 1., 1., 1., 1., 1.],
           [1., 1., 1., 1., 1., 1.],
           [1., 1., 1., 1., 1., 1.],
           [1., 1., 1., 1., 1., 1.]],

          [[1., 1., 1., 1., 1., 1.],
           [1., 1., 1., 1., 1., 1.],
           [1., 1., 1., 1., 1., 1.],
           [1., 1., 1., 1., 1., 1.],
           [1., 1., 1., 1., 1., 1.],
           [1., 1., 1., 1.,

In [6]:
# 定义模型
model = diff_net(64)
# 将模型中所有参数拷贝到GPU端
# 定义优化器
input = torch.ones(1, 64, 128, 128, 128, device="cuda", dtype=dtype)
target = torch.ones(1, 64, 128, 128, 128, device="cuda", dtype=dtype)
target = target*2
# print(target[0,0,:,2,2].shape)
# target[0,0,:,2:4,2:4] = (torch.ones(6,2,2)*2).cuda()

# target = torch.tensor([[[[[2.8893, 3.6438, 3.6557, 3.6557, 3.6438, 2.8893],
#            [3.6438, 4.5361, 4.5508, 4.5508, 4.5361, 3.6438],
#            [3.6557, 4.5508, 4.5656, 4.5656, 4.5508, 3.6557],
#            [3.6557, 4.5508, 4.5656, 4.5656, 4.5508, 3.6557],
#            [3.6438, 4.5361, 4.5508, 4.5508, 4.5361, 3.6438],
#            [2.8893, 3.6438, 3.6557, 3.6557, 3.6438, 2.8893]],

#           [[3.6438, 4.5361, 4.5508, 4.5508, 4.5361, 3.6438],
#            [4.5361, 5.5912, 5.6094, 5.6094, 5.5912, 4.5361],
#            [4.5508, 5.6094, 5.6278, 5.6278, 5.6094, 4.5508],
#            [4.5508, 5.6094, 5.6278, 5.6278, 5.6094, 4.5508],
#            [4.5361, 5.5912, 5.6094, 5.6094, 5.5912, 4.5361],
#            [3.6438, 4.5361, 4.5508, 4.5508, 4.5361, 3.6438]],

#           [[3.6557, 4.5508, 4.5656, 4.5656, 4.5508, 3.6557],
#            [4.5508, 5.6094, 5.6278, 5.6278, 5.6094, 4.5508],
#            [4.5656, 5.6278, 5.6465, 5.6465, 5.6278, 4.5656],
#            [4.5656, 5.6278, 5.6465, 5.6465, 5.6278, 4.5656],
#            [4.5508, 5.6094, 5.6278, 5.6278, 5.6094, 4.5508],
#            [3.6557, 4.5508, 4.5656, 4.5656, 4.5508, 3.6557]],

#           [[3.6557, 4.5508, 4.5656, 4.5656, 4.5508, 3.6557],
#            [4.5508, 5.6094, 5.6278, 5.6278, 5.6094, 4.5508],
#            [4.5656, 5.6278, 5.6465, 5.6465, 5.6278, 4.5656],
#            [4.5656, 5.6278, 5.6465, 5.6465, 5.6278, 4.5656],
#            [4.5508, 5.6094, 5.6278, 5.6278, 5.6094, 4.5508],
#            [3.6557, 4.5508, 4.5656, 4.5656, 4.5508, 3.6557]],

#           [[3.6438, 4.5361, 4.5508, 4.5508, 4.5361, 3.6438],
#            [4.5361, 5.5912, 5.6094, 5.6094, 5.5912, 4.5361],
#            [4.5508, 5.6094, 5.6278, 5.6278, 5.6094, 4.5508],
#            [4.5508, 5.6094, 5.6278, 5.6278, 5.6094, 4.5508],
#            [4.5361, 5.5912, 5.6094, 5.6094, 5.5912, 4.5361],
#            [3.6438, 4.5361, 4.5508, 4.5508, 4.5361, 3.6438]],

#           [[2.8893, 3.6438, 3.6557, 3.6557, 3.6438, 2.8893],
#            [3.6438, 4.5361, 4.5508, 4.5508, 4.5361, 3.6438],
#            [3.6557, 4.5508, 4.5656, 4.5656, 4.5508, 3.6557],
#            [3.6557, 4.5508, 4.5656, 4.5656, 4.5508, 3.6557],
#            [3.6438, 4.5361, 4.5508, 4.5508, 4.5361, 3.6438],
#            [2.8893, 3.6438, 3.6557, 3.6557, 3.6438, 2.8893]]]]]).cuda()
print(target)
opt = torch.optim.SGD(model.parameters(), lr=0.000001)
loss = nn.L1Loss()
for epoch in range(100000):
    # 清空优化器缓存
    opt.zero_grad()
    # 前向传播
    output = model(input)
    if epoch == 99999:
        print(model.weight)
    # 求loss
    c = loss(output, target)
    # 反向传播
    c.backward()
    # 更新参数
    opt.step()
    if epoch % 500 == 0:
        print("epoch {:>3d}: loss = {:>8.3f}".format(epoch, c))

(3, 3, 3)
torch.Size([3, 3, 3])
tensor([[[[[2., 2., 2.,  ..., 2., 2., 2.],
           [2., 2., 2.,  ..., 2., 2., 2.],
           [2., 2., 2.,  ..., 2., 2., 2.],
           ...,
           [2., 2., 2.,  ..., 2., 2., 2.],
           [2., 2., 2.,  ..., 2., 2., 2.],
           [2., 2., 2.,  ..., 2., 2., 2.]],

          [[2., 2., 2.,  ..., 2., 2., 2.],
           [2., 2., 2.,  ..., 2., 2., 2.],
           [2., 2., 2.,  ..., 2., 2., 2.],
           ...,
           [2., 2., 2.,  ..., 2., 2., 2.],
           [2., 2., 2.,  ..., 2., 2., 2.],
           [2., 2., 2.,  ..., 2., 2., 2.]],

          [[2., 2., 2.,  ..., 2., 2., 2.],
           [2., 2., 2.,  ..., 2., 2., 2.],
           [2., 2., 2.,  ..., 2., 2., 2.],
           ...,
           [2., 2., 2.,  ..., 2., 2., 2.],
           [2., 2., 2.,  ..., 2., 2., 2.],
           [2., 2., 2.,  ..., 2., 2., 2.]],

          ...,

          [[2., 2., 2.,  ..., 2., 2., 2.],
           [2., 2., 2.,  ..., 2., 2., 2.],
           [2., 2., 2.,  ..., 2., 2., 

RuntimeError: Depthwise weight should have in_channels=1, got 64