In [1]:
#KPConv

import torch
import math

class KPConv(torch.nn.Module):
    """Kernel Point Convolution (KPConv) layer -- work in progress.

    Currently this module only owns the learnable parameters: one
    (in_channels x out_channels) weight matrix per kernel point and an
    optional bias. Kernel point placement and the forward pass are not
    implemented yet.
    """

    def __init__(self, in_channels, out_channels, kernel_size, radius, sigma,
                 bias=False, dimension=3, inf=1e6, epsilon=1e-9):
        """
        Args:
            in_channels: dimension of input features.
            out_channels: dimension of output features.
            kernel_size: number of kernel points.
            radius: radius used for kernel point initialization.
            sigma: influence radius of each kernel point.
            bias: whether to learn an additive bias (default: False).
            dimension: dimension of the point space.
            inf: value used as "infinity" when generating the padding point.
            epsilon: epsilon for the gaussian influence.
        """
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.radius = radius
        self.sigma = sigma

        self.dimension = dimension
        self.inf = inf
        self.epsilon = epsilon

        # One (in_channels x out_channels) weight matrix per kernel point;
        # the zeros are placeholders overwritten by reset_parameters().
        self.weights = torch.nn.Parameter(
            torch.zeros(self.kernel_size, self.in_channels, self.out_channels))

        if bias:
            self.bias = torch.nn.Parameter(torch.zeros(self.out_channels))
        else:
            # Register as None so `self.bias` always exists and `bias`
            # simply does not appear in parameters()/state_dict().
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the weights and (if present) the bias in place.

        Mirrors torch.nn.Linear.reset_parameters: Kaiming-uniform with
        a=sqrt(5) for the weights (bound 1/sqrt(fan_in)) and U(-b, b) with
        b = 1/sqrt(fan_in) for the bias.
        """
        torch.nn.init.kaiming_uniform_(self.weights, a=math.sqrt(5))

        if self.bias is not None:
            # NOTE: for a (kernel_size, in, out) tensor torch computes
            # fan_in = in_channels * out_channels (dim 1 times dims[2:]).
            fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weights)
            # Guard against empty tensors, as torch.nn.Linear does.
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            # uniform_ (in place) replaces the deprecated torch.nn.init.uniform.
            torch.nn.init.uniform_(self.bias, -bound, bound)

    def initialize_kernel_points(self):
        """Place the kernel points within `self.radius` (not implemented yet).

        Raises:
            NotImplementedError: always, until a placement scheme is chosen.
        """
        # TODO: implement kernel point placement (the original was left empty).
        raise NotImplementedError(
            "KPConv.initialize_kernel_points is not implemented yet")




Net(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)
Model's state_dict:
conv1.weight 	 torch.Size([6, 3, 5, 5])
conv1.bias 	 torch.Size([6])
conv2.weight 	 torch.Size([16, 6, 5, 5])
conv2.bias 	 torch.Size([16])
fc1.weight 	 torch.Size([120, 400])
fc1.bias 	 torch.Size([120])
fc2.weight 	 torch.Size([84, 120])
fc2.bias 	 torch.Size([84])
fc3.weight 	 torch.Size([10, 84])
fc3.bias 	 torch.Size([10])

