### Транспонированная 2D свертка

In [1]:
import numpy as np
class ConvTranspose2D():
    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, output_padding=0, groups=1, bias=True, dilation=1, dtype=None):
        
        self.weights = None
        self.bias_weights = None
        
        self.dtype = dtype
        
        if isinstance(in_channels, int) and in_channels > 0:
            self.in_channels = in_channels
        else:
            raise ValueError("Invalid in_channels")
        
        if isinstance(out_channels, int) and out_channels > 0:
            self.out_channels = out_channels
        else:
            raise ValueError("Invalid out_channels")
        
        if isinstance(groups, int) and groups > 0:
            self.groups = groups
        else:
            raise ValueError("Invalid groups")
        
        if isinstance(bias, int) or isinstance(bias, bool):
            self.bias = bool(bias)
        else:
            raise ValueError("Invalid bias")
        
        if isinstance(kernel_size, tuple):
            try:
                self.kernel1, self.kernel2 = kernel_size
                if not (isinstance(self.kernel1, int) and isinstance(self.kernel2, int)):
                    raise ValueError("Invalid kernel_size types")
            except ValueError:
                raise ValueError("Invalid tuple format for kernel_size")
        elif isinstance(kernel_size, int) and kernel_size > 0:
            self.kernel1 = self.kernel2 = kernel_size
        else:
            raise ValueError("Invalid kernel_size")
        
        if isinstance(stride, tuple):
            try:
                self.stride1, self.stride2 = stride
                if not (isinstance(self.stride1, int) and isinstance(self.stride2, int)):
                    raise ValueError("Invalid stride types")
            except ValueError:
                raise ValueError("Invalid tuple format for stride")
        elif isinstance(stride, int) and stride > 0:
            self.stride1 = self.stride2 = stride
        else:
            raise ValueError("Invalid stride")
        
        if isinstance(padding, tuple):
            try:
                self.padding1, self.padding2 = padding
                if not (isinstance(self.padding1, int) and isinstance(self.padding2, int)):
                    raise ValueError("Invalid padding types")
            except ValueError:
                raise ValueError("Invalid tuple format for padding")
        elif isinstance(padding, int) and padding > -1:
            self.padding1 = self.padding2 = padding
        else:
            raise ValueError("Invalid padding")
            
        if isinstance(output_padding, tuple):
            try:
                self.output_padding1, self.output_padding2 = padding
                if not (isinstance(self.output_padding1, int) and isinstance(self.output_padding2, int)):
                    raise ValueError("Invalid output_padding types")
            except ValueError:
                raise ValueError("Invalid tuple format for output_padding")
        elif isinstance(output_padding, int) and output_padding > -1:
            self.output_padding1 = self.output_padding2 = output_padding
        else:
            raise ValueError("Invalid output_padding")
            
        if isinstance(dilation, tuple):
            try:
                self.dilation1, self.dilation2 = dilation
                if not (isinstance(self.dilation1, int) and isinstance(self.dilation2, int)):
                    raise ValueError("Invalid dilation types")
            except ValueError:
                raise ValueError("Invalid tuple format for dilation")
        elif isinstance(dilation, int) and dilation > 0:
            self.dilation1 = self.dilation2 = dilation
        else:
            raise ValueError("Invalid dilation")
        if not((self.in_channels % self.groups == 0) and (self.out_channels % self.groups == 0)):
            raise ValueError("in_channels and out_channels must both be divisible by groups")
        
        if (self.output_padding1 >= self.dilation1 and self.output_padding1 >= self.stride1) or (self.output_padding2 >= self.dilation2 and self.output_padding2 >= self.stride2):
            raise ValueError("output_padding should be smaller than dilation or stride")
            
    def set_weights(self, weights, bias = None):
        if self.bias == True and (type(bias) == type(np.array([]))):
            if len(bias.shape) == 1 and bias.shape[0] == self.out_channels:
                self.bias_weights = bias
            else:
                raise TypeError("Invalid bias weights shape")
        if self.bias == True and (type(bias) != type(np.array([]))):
            raise TypeError("Invalid bias weights")
        
        if type(weights) != type(np.array([])):
            raise TypeError("Invalid weights")
        if len(weights.shape) != 4:
            raise ValueError("Invalid weights shape")
        if weights.shape[0] != self.in_channels:
            raise ValueError(f"Incorrect axis=0 weights dimension, given {weights.shape[0]}, expected {self.in_channels}")
        if weights.shape[1] != self.out_channels // self.groups:
            raise ValueError(f"Incorrect axis=1 weights dimension, given {weights.shape[1]}, expected {self.out_channels // self.groups}")
        if weights.shape[2] != self.kernel1:
            raise ValueError(f"Incorrect axis=2 weights dimension, given {weights.shape[2]}, expected {self.kernel1}")
        if weights.shape[3] != self.kernel2:
            raise ValueError(f"Incorrect axis=3 weights dimension, given {weights.shape[3]}, expected {self.kernel2}")
        self.weights = weights
        
    
    def __get_conv(self, channels, h_in, w_in, offset=0, bias_offset=0):
        size1 = (h_in - 1) * self.stride1 + self.dilation1 * (self.kernel1 - 1) + 1
        size2 = (w_in - 1) * self.stride2 + self.dilation2 * (self.kernel2 - 1) + 1
        conv = []
        for i in range(len(channels)):
            ch_list = []
            ch = channels[i]
            for b, k in enumerate(self.weights[i+offset]):
                f_map = np.zeros((size1, size2))
                x, y = 0, 0
                for m in range(ch.shape[0]):
                    for n in range(ch.shape[1]):
                        x, y = n * self.stride2, m * self.stride1
                        d = ch[m][n]
                        t_conv = k * d
                        y_m = y + self.kernel1 + (self.kernel1 - 1) * (self.dilation1 - 1)
                        x_m = x + self.kernel2 + (self.kernel2 - 1) * (self.dilation2 - 1)
                        #print(t_conv)
                        #print(t_conv.shape)
                        #print(x, x_m)
                        #print(y, y_m)
                        f_map[y:y_m:self.dilation1, x:x_m:self.dilation2] += t_conv
            
                for _ in range(self.output_padding1):
                    f_map = np.vstack((f_map, np.zeros((1, f_map.shape[1]))))
                for _ in range(self.output_padding2):
                    f_map = np.hstack((f_map, np.zeros((f_map.shape[0], 1))))
                for _ in range(self.padding1):
                    f_map = f_map[1:-1]
                for _ in range(self.padding2):
                    f_map = f_map[:, 1:-1]
                if self.bias:
                    f_map += self.bias_weights[b + bias_offset] / len(channels)
                ch_list.append(f_map)
            conv.append(ch_list)
        conv = np.asarray(conv)
        conv = np.sum(conv, axis=0)
        return conv
    
    def forward(self, tensor):
        if len(tensor.shape) == 3:
            tensor = np.expand_dims(tensor, axis=0)
        if len(tensor.shape) != 4:
            raise ValueError(f"Invalid tensor dimensions = {len(tensor.shape)}, expected 3 or 4")
        N, c_in, h_in, w_in = tensor.shape
                
        c_out = self.out_channels
        h_out = int((h_in - 1) * self.stride1 - 2 * self.padding1 + self.dilation1 * (self.kernel1 - 1) + self.output_padding1 + 1)
        w_out = int((w_in - 1) * self.stride2 - 2 * self.padding2 + self.dilation2 * (self.kernel2 - 1) + self.output_padding2 + 1)
        try:
            out_tensor = np.zeros((N, c_out, h_out, w_out), dtype=self.dtype)
        except:
            raise TypeError("Invalid dtype")
        
        for n in range(N):
            cur_out_channel = 0
            step_for_bias = self.out_channels // self.groups
            step_for_channels = c_in // self.groups
            ch_pos = 0
            bias_pos = 0
            for i in range(self.groups):
                current_channels = tensor[n, ch_pos:ch_pos+step_for_channels]
                conv = self.__get_conv(current_channels, h_in, w_in, offset=ch_pos, bias_offset=bias_pos)
                ch_pos += step_for_channels
                bias_pos += step_for_bias
                for c in conv:
                    out_tensor[n, cur_out_channel] = c
                    cur_out_channel += 1
        return out_tensor

In [2]:
import torch
from tqdm import tqdm

In [3]:
rng = np.random.default_rng()

In [4]:
from tqdm import tqdm
def test():
    for i in (pbar := tqdm(range(10))):
        t = rng.integers(low=0, high=255, size=(1, in_channels, 10, 10)) / 1.0
        tt = torch.Tensor(t)
        t_conv = torch.nn.ConvTranspose2d(in_channels, out_channels, kernel_size,
                                 bias=bias, groups=groups, stride=stride,
                                 padding=padding, output_padding=output_padding, dilation=dilation)

        # берем случайные веса из слоя torch и записываем их в наш кастомный класс
        weights = t_conv.weight.detach().numpy()
        if bias:
            bias_w = t_conv.bias.detach().numpy()
        layer = ConvTranspose2D(in_channels, out_channels, kernel_size,
                                 bias=bias, groups=groups, stride=stride,
                                 padding=padding, output_padding=output_padding, dilation=dilation)
        if bias:
            layer.set_weights(weights, bias_w)
        else:
            layer.set_weights(weights)
        torch_result = t_conv(tt).detach().numpy()
        layer_result = layer.forward(t)
        
        assert np.allclose(torch_result, layer_result, atol=0.0001), "Error"
    print('pass')

In [5]:
# первая группа
kernel_size = (3, 3)
padding = 1
dilation = 1
stride = 1
in_channels = 8
out_channels = 20
bias = True
groups = 2
output_padding = 0

In [6]:
test()

100%|██████████| 10/10 [00:00<00:00, 16.24it/s]

pass





In [7]:
# вторая группа
kernel_size = (4, 2)
padding = 1
dilation = 1
stride = 1
in_channels = 8
out_channels = 20
bias = True
groups = 2
output_padding = 0

In [8]:
test()

100%|██████████| 10/10 [00:00<00:00, 16.76it/s]

pass





In [9]:
# третья группа
kernel_size = (2, 4)
padding = 0
dilation = 1
stride = 2
in_channels = 4
out_channels = 16
bias = True
groups = 4
output_padding = 1

In [10]:
test()

100%|██████████| 10/10 [00:00<00:00, 75.76it/s]

pass





In [11]:
# четвертая группа
kernel_size = (2, 4)
padding = 0
dilation = 3
stride = 2
in_channels = 8
out_channels = 20
bias = True
groups = 2
output_padding = 1

In [12]:
test()

100%|██████████| 10/10 [00:00<00:00, 17.15it/s]

pass





In [13]:
# пятая группа
kernel_size = (5, 3)
padding = 2
dilation = 2
stride = (3, 1)
in_channels = 16
out_channels = 8
bias = True
groups = 1
output_padding = 0

In [14]:
test()

100%|██████████| 10/10 [00:00<00:00, 10.94it/s]

pass





### Реализация транспонированной свертки через обычную

Т.к. формат весов для обычной и транспонированной свертки слегка отличается, а именно, для Conv2D это (out_channels, input_channels/groups, kernel_size[0], kernel_size[1]), а для ConvTranspose2D это (in_channels, out_channels/groups, kernel_size[0], kernel_size[1]), то для задания одинаковых весов слегка изменим логику работы обычного класса Conv2D.

Также следует отметить, что при реализации транспонированной свертки через обычную теряется возможность управления некоторыми параметрами, а именно padding, output_padding и dilation. Несмотря на то, что существуют некоторые редкие "специфичные" библиотечные реализации, позволяющие задавать эти параметры их реализация, очевидно, выходит за рамки рассматриваемой темы, поэтому единственными параметрами, которыми можно будет управлять это in_channels, out_channels, kernel_size, stride и groups.

In [15]:
# меняем класс под нормальную логику работы с groups
import numpy as np
class Conv2D():
    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, output_padding=0, dilation=1, groups=1,
                 bias=True, padding_mode='zeros', dtype=None):
        
        self.weights = None
        self.bias_weights = None
        
        self.dtype = dtype
        
        padding_modes = ['zeros', 'replicate']
        if padding_mode not in padding_modes:
            raise ValueError("Invalid padding_mode")
        self.padding_mode = padding_mode
        
        if isinstance(in_channels, int) and in_channels > 0:
            self.in_channels = in_channels
        else:
            raise ValueError("Invalid in_channels")
        
        if isinstance(out_channels, int) and out_channels > 0:
            self.out_channels = out_channels
        else:
            raise ValueError("Invalid out_channels")
        
        if isinstance(groups, int) and groups > 0:
            self.groups = groups
        else:
            raise ValueError("Invalid groups")
        
        if isinstance(bias, int) or isinstance(bias, bool):
            self.bias = bool(bias)
        else:
            raise ValueError("Invalid bias")
        
        if isinstance(kernel_size, tuple):
            try:
                self.kernel1, self.kernel2 = kernel_size
                if not (isinstance(self.kernel1, int) and isinstance(self.kernel2, int)):
                    raise ValueError("Invalid kernel_size types")
            except ValueError:
                raise ValueError("Invalid tuple format for kernel_size")
        elif isinstance(kernel_size, int) and kernel_size > 0:
            self.kernel1 = self.kernel2 = kernel_size
        else:
            raise ValueError("Invalid kernel_size")
        
        if isinstance(stride, tuple):
            try:
                self.stride1, self.stride2 = stride
                if not (isinstance(self.stride1, int) and isinstance(self.stride2, int)):
                    raise ValueError("Invalid stride types")
            except ValueError:
                raise ValueError("Invalid tuple format for stride")
        elif isinstance(stride, int) and stride > 0:
            self.stride1 = self.stride2 = stride
        else:
            raise ValueError("Invalid stride")
        
        if isinstance(padding, str):
            if padding in ["valid", "same"]:
                if padding == 'same' and self.stride1 != 1:
                    raise ValueError("padding 'same' is not valid for stride > 1")
                self.padding1 = self.padding2 = padding
            else:
                raise ValueError("Invalid padding")
        elif isinstance(padding, tuple):
            try:
                self.padding1, self.padding2 = padding
                if not (isinstance(self.padding1, int) and isinstance(self.padding2, int)):
                    raise ValueError("Invalid padding types")
            except ValueError:
                raise ValueError("Invalid tuple format for padding")
        elif isinstance(padding, int) and padding > -1:
            self.padding1 = self.padding2 = padding
        else:
            raise ValueError("Invalid padding")
            
        if isinstance(dilation, tuple):
            try:
                self.dilation1, self.dilation2 = dilation
                if not (isinstance(self.dilation1, int) and isinstance(self.dilation2, int)):
                    raise ValueError("Invalid dilation types")
            except ValueError:
                raise ValueError("Invalid tuple format for dilation")
        elif isinstance(dilation, int) and dilation > 0:
            self.dilation1 = self.dilation2 = dilation
        else:
            raise ValueError("Invalid dilation")
        if not((self.in_channels % self.groups == 0) and (self.out_channels % self.groups == 0)):
            raise ValueError("in_channels and out_channels must both be divisible by groups")
            
    def set_weights(self, weights, bias = None):
        if self.bias == True and (type(bias) == type(np.array([]))):
            if len(bias.shape) == 1 and bias.shape[0] == self.out_channels:
                self.bias_weights = bias
            else:
                raise TypeError("Invalid bias weights shape")
        if self.bias == True and (type(bias) != type(np.array([]))):
            raise TypeError("Invalid bias weights")
        
        if type(weights) != type(np.array([])):
            raise TypeError("Invalid weights")
        if len(weights.shape) != 4:
            raise ValueError("Invalid weights shape")
        if weights.shape[0] != self.in_channels:
            raise ValueError(f"Incorrect axis=0 weights dimension, given {weights.shape[0]}, expected {self.in_channels}")
        if weights.shape[1] != self.out_channels // self.groups:
            raise ValueError(f"Incorrect axis=1 weights dimension, given {weights.shape[1]}, expected {self.out_channels // self.groups}")
        if weights.shape[2] != self.kernel1:
            raise ValueError(f"Incorrect axis=2 weights dimension, given {weights.shape[2]}, expected {self.kernel1}")
        if weights.shape[3] != self.kernel2:
            raise ValueError(f"Incorrect axis=3 weights dimension, given {weights.shape[3]}, expected {self.kernel2}")
        self.weights = weights
        
    
    def __get_conv(self, channels, h_out, w_out, offset=0, bias_offset=0):
        k1_m = self.kernel1 + (self.kernel1 - 1) * (self.dilation1 - 1)
        k2_m = self.kernel2 + (self.kernel2 - 1) * (self.dilation2 - 1)
        conv = []
        
        for i in range(len(channels)):
            ch_list = []
            ch = channels[i]
            for h in range(self.padding1):
                if (self.padding_mode == 'zeros'):
                    ch = np.vstack((ch, np.zeros(ch.shape[1])))
                    ch = np.vstack((np.zeros(ch.shape[1]), ch))
                elif (self.padding_mode == 'replicate'):
                    ch = np.vstack((ch, np.array(ch[-1])))
                    ch = np.vstack((np.array(ch[0]), ch))

            for w in range(self.padding2):
                if (self.padding_mode == 'zeros'):
                    ch = np.hstack((ch, np.zeros((ch.shape[0], 1))))
                    ch = np.hstack((np.zeros((ch.shape[0], 1)), ch))
                elif (self.padding_mode == 'replicate'):
                    ch = np.hstack((ch, np.expand_dims(np.array(ch[:, -1]), axis=0).T))
                    ch = np.hstack((np.expand_dims(np.array(ch[:, 0]), axis=0).T, ch))

            if ch.shape[0] < k1_m or ch.shape[1] < k2_m:
                raise RuntimeError(f"Channel shape which is {ch.shape} smaller than calculated kernel {k1_m, k2_m}")
            
            for b, k in enumerate(self.weights[i+offset]):
                out = np.zeros((h_out, w_out), dtype=self.dtype)
                m_ind, n_ind = 0, 0
                x, y = 0, 0
                m, n = out.shape
                while m_ind < m:
                    x = 0
                    n_ind = 0
                    while n_ind < n:
                        out[m_ind, n_ind] += np.sum(ch[y:k1_m+y:self.dilation1, x:k2_m+x:self.dilation2] * k)
                        n_ind += 1
                        x += self.stride2
                    y += self.stride1
                    m_ind += 1
                if self.bias:
                    out += self.bias_weights[b+bias_offset] / len(channels)
                
                ch_list.append(out)
            
            conv.append(ch_list)
        
        conv = np.asarray(conv)
        conv = np.sum(conv, axis=0)    
        
        return conv    
    
    
    def forward(self, tensor):
        if len(tensor.shape) == 3:
            tensor = np.expand_dims(tensor, axis=0)
        if len(tensor.shape) != 4:
            raise ValueError(f"Invalid tensor dimensions = {len(tensor.shape)}, expected 3 or 4")
        N, c_in, h_in, w_in = tensor.shape
        k1_m = self.kernel1 + (self.kernel1 - 1) * (self.dilation1 - 1)
        k2_m = self.kernel2 + (self.kernel2 - 1) * (self.dilation2 - 1)
        if self.padding1 == 'valid':
            self.padding1 = self.padding2 = 0
        if self.padding1 == 'same':
            self.padding1 = int((h_in * self.stride1 - self.stride1 - h_in + self.dilation1 * (self.kernel1 - 1) + 1) / 2)
            self.padding2 = int((w_in * self.stride2 - self.stride2 - w_in + self.dilation2 * (self.kernel2 - 1) + 1) / 2)
                
        c_out = self.out_channels
        h_out = int((h_in + 2 * self.padding1 - self.dilation1 * (self.kernel1 - 1) - 1) / self.stride1) + 1
        w_out = int((w_in + 2 * self.padding2 - self.dilation2 * (self.kernel2 - 1) - 1) / self.stride2) + 1
        try:
            out_tensor = np.zeros((N, c_out, h_out, w_out), dtype=self.dtype)
        except:
            raise TypeError("Invalid dtype")
        
        for n in range(N):
            cur_out_channel = 0
            step_for_bias = self.out_channels // self.groups
            step_for_channels = c_in // self.groups
            ch_pos = 0
            bias_pos = 0
            for i in range(self.groups):
                current_channels = tensor[n, ch_pos:ch_pos+step_for_channels]
                conv = self.__get_conv(current_channels, h_out, w_out, offset=ch_pos, bias_offset=bias_pos)
                ch_pos += step_for_channels
                bias_pos += step_for_bias
                for c in conv:
                    out_tensor[n, cur_out_channel] = c
                    cur_out_channel += 1
        return out_tensor

In [16]:
def preprocess_tensor(tensor, stride, kernel_size):
    '''
    Данная функция выполняет преобразование исходного тензора, для возможности применения обычной 2D свертки,
    с целью получить результат транспонированной свертки
    
    tensor - np.array, axis_count = 4;
    stride - tuple;
    kernel_size - tuple;
    
    return out_tensor - np.array, axis_count = 4
    '''
    stride_h, stride_w = stride
    kernel_h, kernel_w = kernel_size
    N, c_in, h_in, w_in = tensor.shape
    h_out = 2 * (kernel_h - 1) + h_in + (h_in - 1) * (stride_h - 1)
    w_out = 2 * (kernel_w - 1) + w_in + (w_in - 1) * (stride_w - 1)
    out_tensor = np.zeros((N, c_in, h_out, w_out))
    for n in range(N):
        for c in range(c_in):
            m = tensor[n, c] # входная мапа
            y, x = kernel_h - 1, kernel_w - 1 # левый верхний угол для заполнения
            y_m, x_m = h_out - kernel_h + 1, w_out - kernel_w + 1 # правый нижний угол
            out_tensor[n, c, y:y_m:stride_h, x:x_m:stride_w] = m
    return out_tensor

In [17]:
from tqdm import tqdm
def test():
    for i in (pbar := tqdm(range(10))):
        t = rng.integers(low=0, high=255, size=(5, in_channels, 10, 10)) / 1.0
        tt = torch.Tensor(t)
        t_conv = torch.nn.ConvTranspose2d(in_channels, out_channels, kernel_size,
                                 bias=bias, groups=groups, stride=stride)

        # берем случайные веса из слоя torch и записываем их в наш кастомный класс
        weights = t_conv.weight.detach().numpy()
        if bias:
            bias_w = t_conv.bias.detach().numpy()
        layer = Conv2D(in_channels, out_channels, kernel_size,
                                 bias=bias, groups=groups)
        if bias:
            layer.set_weights(np.flip(weights, axis=(3, 2)), bias_w)
        else:
            layer.set_weights(np.flip(weights, axis=(3, 2)))
        torch_result = t_conv(tt).detach().numpy()
        layer_result = layer.forward(preprocess_tensor(t, stride, kernel_size))
        
        assert np.allclose(torch_result, layer_result, atol=0.001), "Error"
    print('pass')

In [18]:
# первая группа
kernel_size = (3, 3)
stride = (1, 1)
in_channels = 16
out_channels = 8
bias = True
groups = 2

In [19]:
test()

100%|██████████| 10/10 [00:03<00:00,  2.81it/s]

pass





In [20]:
# вторая группа
kernel_size = (2, 4)
stride = (4, 2)
in_channels = 4
out_channels = 20
bias = True
groups = 4

In [21]:
test()

100%|██████████| 10/10 [00:06<00:00,  1.62it/s]

pass





In [22]:
# третья группа
kernel_size = (3, 2)
stride = (5, 3)
in_channels = 20
out_channels = 60
bias = True
groups = 5

In [23]:
test()

100%|██████████| 10/10 [01:57<00:00, 11.76s/it]

pass





In [24]:
# четвертая группа
kernel_size = (7, 7)
stride = (2, 2)
in_channels = 3
out_channels = 32
bias = True
groups = 1

In [25]:
test()

100%|██████████| 10/10 [00:21<00:00,  2.18s/it]

pass



