In [46]:
import torch
from torch import nn
from torch.nn import functional as F
import torchvision
from torchvision import models,transforms,datasets
from torch.utils import data
import os
# script_dir = os.path.dirname(os.path.abspath(__file__))
script_dir = os.getcwd()

# 精简输出精度
torch.set_printoptions(4)

def multibox_prior(data, sizes, ratios):
    """生成以每个像素为中心具有不同形状的锚框"""
    # in_height, in_width = data.shape[-2:]
    in_height = data.shape[-2]
    in_width = data.shape[-1]
    print("in_height", in_height)
    print("in_width", in_width)

    device, num_sizes, num_ratios = data.device, len(sizes), len(ratios)
    boxes_per_pixel = (num_sizes + num_ratios -1)
    size_tensor = torch.tensor(sizes, device=device)
    ratio_tensor = torch.tensor(ratios, device=device)

    # 为了将锚点移动到像素的中心，需要设置偏移量
    # 因为一个像素的高为1且宽为1，我们选择偏移中心0.5
    offset_h, offset_w = 0.5, 0.5
    steps_h = 1.0 / in_height # 在y轴上缩放步长
    steps_w = 1.0 / in_width # 在x轴上缩放步长

    # 生成锚框的所有中心点
    center_h = (torch.arange(in_height, device=device) + offset_h) * steps_h
    print("center_h", center_h)
    center_w = (torch.arange(in_width, device=device) + offset_w) * steps_w
    print("center_w", center_w)
    # 输出形状都是n1 x n2，其中第一个输出张量每列填充第一个输入张量，各行元素相同；第二个输出张量每行填充第二个输入张量，各列元素相同
    shift_y, shift_x = torch.meshgrid(center_h, center_w, indexing='ij')
    print("shift_y", shift_y)
    print("shift_x", shift_x)
    shift_y, shift_x = shift_y.reshape(-1), shift_x.reshape(-1)
    print("shift_y", shift_y)
    print("shift_x", shift_x)

    print("size_tensor * torch.sqrt(ratio_tensor[0])", size_tensor * torch.sqrt(ratio_tensor[0]))
    print("sizes[0] * torch.sqrt(ratio_tensor[1:])", sizes[0] * torch.sqrt(ratio_tensor[1:]))
    print(torch.cat((size_tensor * torch.sqrt(ratio_tensor[0]), sizes[0] * torch.sqrt(ratio_tensor[1:]))))
    # 生成“boxes_per_pixel”个高和宽，
    # 之后用于创建锚框的四角坐标(xmin, xmax, ymin, ymax)
    w = torch.cat((size_tensor * torch.sqrt(ratio_tensor[0]), sizes[0] * torch.sqrt(ratio_tensor[1:]))) * in_height / in_width # 处理矩形输入
    h = torch.cat((size_tensor / torch.sqrt(ratio_tensor[0]), sizes[0] / torch.sqrt(ratio_tensor[1:])))
    # 除以2来获得半高和半宽
    print(torch.stack((-w, -h, w, h)))
    print(torch.stack((-w, -h, w, h)).T.repeat(in_height * in_width, 1).shape)
    anchor_manipulations = torch.stack((-w, -h, w, h)).T.repeat(in_height * in_width, 1) / 2

    # 每个中心点都将有"boxes_per_pixel"个锚框
    # 所以生成含所有锚框中心的网络，重复了"boxes_per_pixel"次
    out_grid = torch.stack([shift_x, shift_y, shift_x, shift_y], dim=1).repeat_interleave(boxes_per_pixel, dim=0)
    output = out_grid + anchor_manipulations
    print(1, output.shape)
    print(2, output.unsqueeze(0).shape)
    return output.unsqueeze(0)

In [47]:
import matplotlib
matplotlib.use("Agg")  # 这一句一定要放在下面这句的前面
from matplotlib import pyplot as plt

img = plt.imread('../img/catdog.jpg')
h, w = img.shape[:2]

print(h, w)
X = torch.rand(size=(1, 3, 4, 5))
Y = multibox_prior(X, sizes=[0.75, 0.5, 0.25], ratios=[1, 2, 0.5])
print(Y.shape)

561 728
in_height 4
in_width 5
center_h tensor([0.1250, 0.3750, 0.6250, 0.8750])
center_w tensor([0.1000, 0.3000, 0.5000, 0.7000, 0.9000])
shift_y tensor([[0.1250, 0.1250, 0.1250, 0.1250, 0.1250],
        [0.3750, 0.3750, 0.3750, 0.3750, 0.3750],
        [0.6250, 0.6250, 0.6250, 0.6250, 0.6250],
        [0.8750, 0.8750, 0.8750, 0.8750, 0.8750]])
shift_x tensor([[0.1000, 0.3000, 0.5000, 0.7000, 0.9000],
        [0.1000, 0.3000, 0.5000, 0.7000, 0.9000],
        [0.1000, 0.3000, 0.5000, 0.7000, 0.9000],
        [0.1000, 0.3000, 0.5000, 0.7000, 0.9000]])
shift_y tensor([0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.3750, 0.3750, 0.3750, 0.3750,
        0.3750, 0.6250, 0.6250, 0.6250, 0.6250, 0.6250, 0.8750, 0.8750, 0.8750,
        0.8750, 0.8750])
shift_x tensor([0.1000, 0.3000, 0.5000, 0.7000, 0.9000, 0.1000, 0.3000, 0.5000, 0.7000,
        0.9000, 0.1000, 0.3000, 0.5000, 0.7000, 0.9000, 0.1000, 0.3000, 0.5000,
        0.7000, 0.9000])
size_tensor * torch.sqrt(ratio_tensor[0]) tensor([0.7500,