In [1]:
import math

import numpy as np
import torch
import torch.nn.functional as F

In [41]:
def bicubic_weight_function(x, a=-0.75):
    """Evaluate the cubic-convolution (Keys) interpolation kernel at distances ``x``.

    Piecewise kernel on |x|:
        |x| <= 1      : (a + 2)|x|^3 - (a + 3)|x|^2 + 1
        1 < |x| <= 2  : a|x|^3 - 5a|x|^2 + 8a|x| - 4a
        otherwise     : 0

    Args:
        x: tensor of tap distances; the sign is ignored. In this notebook it is a
            1-D tensor of the 4 distances to the taps of a bicubic window.
        a: kernel shape parameter. The default -0.75 is the value with which this
            notebook's hand-rolled result matches ``F.interpolate(mode='bicubic')``.

    Returns:
        For a 1-D ``x`` of length n, a tensor of shape (4, n): the weights repeated
        as 4 identical rows, ready to multiply a 4x4 window column-wise (use ``.T``
        to weight the rows instead).
    """
    x = torch.abs(x)
    x2 = torch.pow(x, 2)
    x3 = torch.pow(x, 3)
    # Inner lobe (|x| <= 1); the boolean mask zeroes it outside its support.
    near = ((a + 2.0) * x3 - (a + 3.0) * x2 + 1.0) * (x <= 1.0)
    # Outer lobe (1 < |x| <= 2); taps further than 2 get weight 0.
    far = (a * x3 - 5.0 * a * x2 + 8.0 * a * x - 4.0 * a) * ((x > 1.0) & (x <= 2.0))
    weight = near + far
    return weight.repeat(4, 1)

In [42]:
if __name__ == '__main__':
    # A tiny 1x1x2x2 (N, C, H, W) test image and the requested output size.
    image = torch.tensor([[[[0, 1],
                            [2, 3]]]], dtype=torch.float32)
    out_height, out_width = 3, 3

    _, _, height, width = image.shape
    # Input pixels per output pixel along each axis (align_corners=False convention).
    scale_y = height / out_height
    scale_x = width / out_width

    # Replicate-pad by 2 on every side so every 4x4 neighbourhood stays in bounds;
    # taps that fall outside the image reuse the nearest edge pixel.
    image_pad = F.pad(image, (2, 2, 2, 2), mode='replicate')

    # Output buffer. The scalar assignment below assumes N == C == 1 (as for `image`).
    output_image = np.zeros((out_height, out_width), dtype=np.float32)

    for out_y in range(out_height):
        for out_x in range(out_width):
            # Source coordinate (out + 0.5) * scale - 0.5 in input pixels,
            # shifted by +2 into the padded image.
            x = (out_x + 0.5) * scale_x + 1.5
            y = (out_y + 0.5) * scale_y + 1.5
            # Use floor, not round(x + 0.5): Python's round() is half-to-even, so
            # when x is an exact integer it would select the wrong 4x4 window.
            x0 = math.floor(x)
            y0 = math.floor(y)
            delta_x = x - x0
            delta_y = y - y0
            # Distances from the sample point to the taps x0-1, x0, x0+1, x0+2.
            distance_x = torch.tensor([delta_x + 1, delta_x, 1 - delta_x, 2 - delta_x])
            distance_y = torch.tensor([delta_y + 1, delta_y, 1 - delta_y, 2 - delta_y])

            # weight_x[i, j] varies only with column j; weight_y[i, j] only with row i.
            weight_x = bicubic_weight_function(distance_x)
            weight_y = bicubic_weight_function(distance_y).T

            # 4x4 window: rows y0-1..y0+2, columns x0-1..x0+2 of the padded image.
            source = image_pad[:, :, y0 - 1:y0 + 3, x0 - 1:x0 + 3].squeeze()

            # Separable bicubic: weight each tap by its column and row weights, then sum.
            output_image[out_y, out_x] = (source * weight_x * weight_y).sum()
    print(output_image)
            

[[-0.26041672  0.32638872  0.913194  ]
 [ 0.9131942   1.5         2.0868053 ]
 [ 2.0868049   2.673611    3.2604156 ]]


In [43]:
result = F.interpolate(image, (out_height, out_width), mode='bicubic', align_corners=False)
# The hand-rolled bicubic loop must agree with PyTorch's built-in resize; fail loudly otherwise.
np.testing.assert_allclose(result.squeeze().numpy(), output_image, rtol=1e-5, atol=1e-4)
print(result)

tensor([[[[-0.2604,  0.3264,  0.9132],
          [ 0.9132,  1.5000,  2.0868],
          [ 2.0868,  2.6736,  3.2604]]]])
