In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from basicsr.models.archs.arch_util import LayerNorm2d
from basicsr.models.archs.local_arch import Local_Base
# from torchinfo import summary
import numpy as np
# from scipy.linalg import hadamard
from hadamard_transform import hadamard_transform

In [None]:
def find_min_power(x, p=2):
    """Return the smallest power of ``p`` (``p**k`` with ``k >= 0``) that is >= ``x``.

    In this notebook it rounds a feature count up to a valid Hadamard
    transform length (a power of two). Any ``x <= 1`` yields 1.

    Parameters:
        x: the value to round up.
        p: the base (default 2).
    Returns:
        The smallest ``p**k >= x`` reachable by repeated multiplication from 1.
    Raises:
        ValueError: if ``x > 1`` and ``|p| <= 1``. Repeated multiplication
            could then never reach ``x``, and the loop would spin forever.
    """
    if x > 1 and abs(p) <= 1:
        raise ValueError('base p must satisfy |p| > 1 to reach x={}, got p={}'.format(x, p))
    y = 1
    while y < x:
        y *= p
    return y

In [None]:
class SoftThresholding(torch.nn.Module):
    """Learnable per-feature shrinkage applied along the last axis.

    Computes ``tanh(x) * relu(|x| - |T|)``:
    - entries whose magnitude is at most ``|T|`` become exactly zero;
    - larger entries become ``tanh(x) * (|x| - |T|)``.

    Classic soft thresholding would use ``sign(x)`` in place of ``tanh(x)``.

    Args:
        num_features: length of the learnable threshold vector ``T``. ``T`` is
            broadcast against the last axis of the input.
    """

    def __init__(self, num_features):
        super().__init__()
        self.num_features = num_features
        # Thresholds are initialised uniformly in [0, 0.1); only |T| is used.
        self.T = torch.nn.Parameter(torch.rand(self.num_features)/10)

    def forward(self, x):
        # NOTE(review): tanh is presumably used instead of sign so that the
        # sign-like factor has a non-zero gradient; confirm this is intended.
        return torch.mul(torch.tanh(x), torch.nn.functional.relu(torch.abs(x)-torch.abs(self.T)))

In [None]:
def hadamard_transform(u, axis=-1, fast=True):
    """Apply the orthonormal Walsh-Hadamard transform along one axis of ``u``.

    Computes ``H_n @ u / sqrt(n)`` along ``axis``. ``H_n`` is the
    Sylvester-ordered ``n x n`` Hadamard matrix, the same ordering that
    ``scipy.linalg.hadamard`` produces. The scaled matrix is symmetric and
    orthogonal, so applying the transform twice returns the input.

    NOTE: this cell redefines, and therefore shadows, the ``hadamard_transform``
    imported from the ``hadamard_transform`` package at the top of the notebook.

    Parameters:
        u: tensor whose length ``n`` along ``axis`` is a power of 2.
        axis: the axis to transform (default: the last axis).
        fast: if True, use the O(n log n) butterfly. Otherwise multiply by an
            explicit ``n x n`` Hadamard matrix.
    Returns:
        A tensor with the same shape as ``u``. Integer inputs give a
        floating-point result.
    Raises:
        AssertionError: if ``n`` is not a positive power of 2.
    """
    if axis != -1:
        u = torch.transpose(u, -1, axis)

    n = u.shape[-1]
    # Integer check: avoids float log2 and the OverflowError int(-inf) raised for n == 0.
    assert n > 0 and n & (n - 1) == 0, 'n must be a power of 2'
    m = n.bit_length() - 1

    if fast:
        # Butterfly: each pass folds even/odd rows into sum/difference columns.
        # After m passes the trailing (1, n) slice holds H_n @ u.
        x = u[..., np.newaxis]
        for _ in range(m):
            x = torch.cat((x[..., ::2, :] + x[..., 1::2, :], x[..., ::2, :] - x[..., 1::2, :]), dim=-1)
        y = x.squeeze(-2) / 2**(m / 2)
    else:
        # Build H_n with the Sylvester recursion H_2k = [[H_k, H_k], [H_k, -H_k]]
        # (the matrix scipy.linalg.hadamard returns). The previous code called
        # an undefined ``hadamard`` here and raised NameError.
        if u.is_floating_point() or u.is_complex():
            dtype = u.dtype
        else:
            dtype = torch.get_default_dtype()
        H = torch.ones((1, 1), dtype=dtype, device=u.device)
        for _ in range(m):
            H = torch.cat((torch.cat((H, H), dim=1), torch.cat((H, -H), dim=1)), dim=0)
        y = u.to(dtype) @ H.t() / np.sqrt(n)

    if axis != -1:
        y = torch.transpose(y, -1, axis)

    return y

In [None]:
class WHT_expansion(torch.nn.Module):
    """Expand the channel (last) axis through a Walsh-Hadamard-domain shrinkage.

    Forward pipeline:
    1. Zero-pad the last axis to ``find_min_power(output_features)``.
    2. Apply ``hadamard_transform``.
    3. Shrink the coefficients with a learnable ``SoftThresholding``.
    4. Transform back.
    5. Keep the first ``output_features`` entries.

    Args:
        input_features: expected length of the input's last axis. Any other
            length raises ``Exception``.
        output_features: length of the output's last axis.
        residual: if True, add the input to the output (shortcut connection).
            NOTE(review): the addition only works when the two last-axis
            lengths are equal or broadcastable, so it fails for a real expansion.
        retain_DC: if True, the DC coefficient (index 0 in the transform
            domain) bypasses the soft threshold.
    """

    def __init__(self, input_features, output_features, residual=False, retain_DC=True):
        super().__init__()
        self.input_features = input_features
        self.output_features = output_features
        # Power-of-two working length for the transform.
        self.num_features_pad = find_min_power(self.output_features)
        self.ST = SoftThresholding(self.num_features_pad)
        self.residual = residual
        self.retain_DC = retain_DC

    def forward(self, x):
        n_in = x.shape[-1]
        if n_in != self.input_features:
            raise Exception('{}!={}'.format(n_in, self.input_features))

        # Zero-pad the last axis up to the working length when it is shorter.
        pad_amount = self.num_features_pad - n_in
        padded = torch.nn.functional.pad(x, (0, pad_amount)) if pad_amount > 0 else x

        spectrum = hadamard_transform(padded)
        shrunk = self.ST(spectrum)
        if self.retain_DC:
            # Put the untouched DC coefficient back in slot 0.
            shrunk = torch.cat((spectrum[..., :1], shrunk[..., 1:]), dim=-1)

        out = hadamard_transform(shrunk)[..., :self.output_features]
        if self.residual:
            out = out + x
        return out

In [None]:
class WHT_projection(torch.nn.Module):
    """Reduce the channel (last) axis by pooling Walsh-Hadamard coefficients.

    Let ``N = find_min_power(input_features)``, ``M = find_min_power(output_features)``
    and ``r = N // M``. Forward pipeline:
    1. Zero-pad the input's last axis to ``N`` and apply ``hadamard_transform``.
    2. Keep the DC coefficient, scaled by ``1 / r``.
    3. Average coefficients ``1 .. N - r`` in non-overlapping groups of ``r``,
       giving ``M - 1`` values. Together with the DC term that makes ``M``
       coefficients.
    4. Transform back and keep the first ``output_features`` entries.

    Args:
        input_features: expected length of the input's last axis. Any other
            length raises ``Exception``.
        output_features: length of the output's last axis. Its padded size
            must not exceed the padded input size.
        residual: if True, add the input to the output (shortcut connection).
            NOTE(review): the shapes only match when the input and output
            feature counts are equal (or broadcastable).
        retain_DC: stored for API parity with ``WHT_expansion``.
            TODO(review): ``retain_DC`` and ``self.ST`` are not used in
            ``forward``; the soft threshold was never applied here.
    Raises:
        ValueError: if ``find_min_power(output_features)`` exceeds
            ``find_min_power(input_features)``.
    """

    def __init__(self, input_features , output_features , residual=False , retain_DC=True):
        super().__init__()
        self.input_features = input_features
        self.output_features = output_features
        self.input_features_pad = find_min_power(self.input_features)
        self.output_features_pad = find_min_power(self.output_features)
        if self.output_features_pad > self.input_features_pad:
            raise ValueError('output_features ({}) must not exceed input_features ({}) after padding'.format(
                self.output_features, self.input_features))
        # Pooling factor. Both padded sizes are already powers of two, so their
        # ratio is an exact integer division (not 2 ** difference).
        self.r = self.input_features_pad // self.output_features_pad
        # Exclusive end of the kept transform-domain slice [0, end_channel):
        # the DC term plus (output_features_pad - 1) * r AC terms.
        self.end_channel = self.input_features_pad - self.r + 1
        self.ST = SoftThresholding(self.end_channel)

        self.residual = residual
        self.retain_DC = retain_DC

    def forward(self, x):
        input_features = x.shape[-1]
        if input_features != self.input_features:
            raise Exception('{}!={}'.format(input_features, self.input_features))
        if self.input_features_pad > input_features:
            f0 = torch.nn.functional.pad(x, (0, self.input_features_pad - input_features))
        else:
            f0 = x
        f1 = hadamard_transform(f0)

        # AC coefficients 1 .. end_channel - 1: exactly (M - 1) * r of them.
        f2 = f1[..., 1:self.end_channel]
        # Average non-overlapping groups of r along the last axis. This matches
        # F.avg_pool1d(kernel_size=r, stride=r) but works for any leading rank
        # without a per-sample Python loop.
        f3 = f2.reshape(*f2.shape[:-1], self.output_features_pad - 1, self.r).mean(dim=-1)
        # Keep DC as a length-1 slice so it concatenates along the last axis.
        f4 = torch.cat((f1[..., :1] / self.r, f3), dim=-1)
        # Return to the feature domain (previously the code returned f4).
        f5 = hadamard_transform(f4)
        y = f5[..., :self.output_features]
        if self.residual:
            y = y + x
        return y

In [None]:
import torch
from torch import nn
# Demo input: float so that average pooling is well-defined. On int64 it is
# truncated or unsupported, depending on the torch version.
img = torch.arange(8 * 4 * 4, dtype=torch.float32).reshape(1, 8, 4, 4)
# Move the 8 channels to the last axis: (1, 4, 4, 8).
img_t = img.permute(0, 3, 2, 1)
# Pooling kernel and stride are both 2.
pool = nn.AvgPool1d(2, stride=2)
img_2 = pool(img_t[0])
n = img_t.shape[0]
# Functional equivalent applied per sample, then re-stacked into a batch.
temp = torch.stack([F.avg_pool1d(img_t[i], kernel_size=2, stride=2) for i in range(n)])
# The module and functional forms must agree on the first sample.
assert torch.equal(temp[0], img_2)

print(img, img.shape)
print(img_t, img_t.shape)
print(img_2, img_2.shape)
print(temp, temp.shape)

In [None]:
# img_t carries 8 channels on its last axis, so input_features must match it
# (the previous value 4 always raised Exception('8!=4')). Expand 8 -> 16 channels.
net = WHT_expansion(input_features=img_t.shape[-1], output_features=16)
ex_ten = net(img_t)
print(ex_ten.shape)

In [None]:
# Project the 8 channels on img_t's last axis down to 2.
net1 = WHT_projection(output_features=2, input_features=8)
ex_ten1 = net1(img_t)
print(ex_ten1.size())