In [2]:
import torch
import torchvision
from torch import nn
from torch.nn import functional as F


In [3]:
def cls_predictor(num_inputs,num_anchors,num_classes):
    return nn.Conv2d(num_inputs,num_anchors*(num_classes + 1),
    kernel_size  = 3, padding=1)


In [4]:
def bbox_predictor(num_inputs,num_anchors):
    return nn.Conv2d(num_inputs,
                    num_anchors*4,
                    kernel_size=3,
                    padding=1)

In [5]:
def forward(x,block):
    return block(x)

y1 = forward(torch.zeros((2,8,20,20)),cls_predictor(8,5,10))
y2 = forward(torch.zeros((2,16,10,10)),cls_predictor(16,3,10))
y1.shape,y2.shape

(torch.Size([2, 55, 20, 20]), torch.Size([2, 33, 10, 10]))

In [6]:
def flatten_pred(pred):
    return torch.flatten(pred.permute(0,2,3,1),start_dim=1)

def concat_preds(preds):
    return torch.cat([flatten_pred(p) for p in preds],dim=1)

concat_preds([y1,y2]).shape

torch.Size([2, 25300])

In [7]:
def down_sample_blk(in_channels,out_channels):
    blk = []
    for _ in range(2):
        blk.append(
            nn.Conv2d(in_channels=in_channels,out_channels=out_channels,
                      kernel_size=3,padding=1)
        )
        blk.append(nn.BatchNorm2d(out_channels))
        blk.append(nn.ReLU())
        in_channels = out_channels
    blk.append(nn.MaxPool2d(2))
    return nn.Sequential(*blk)


forward(torch.zeros((2,3,20,20)),down_sample_blk(3,10)).shape


torch.Size([2, 10, 10, 10])

In [8]:
def base_net():
    blk = []
    num_filters = [3,16,32,64]
    for i in range(len(num_filters)- 1):
        blk.append(down_sample_blk(num_filters[i],num_filters[i+1]))

    return nn.Sequential(*blk)

forward(torch.zeros((2,3,256,256)),base_net()).shape

torch.Size([2, 64, 32, 32])

In [9]:
def multibox_prioe(data,sizes,ratios):
    in_heigjt , in_width = data.shape[-2:]
    device ,num_size ,num_ratios = data.device,len(sizes),len(ratios)
    boxes_per_pixel = (num_size + num_ratios - 1)
    size_tensor = torch.tensor(sizes,device=device)
    ratios_tensor = torch.tensor(ratios,device=device)

    offset_h,offset_w = 0.5,0.5
    #设置中心点偏移

    steps_h = 1.0/in_heigjt
    steps_w = 1.0/in_width
    #计算相对比例
    center_h = (torch.arange(in_heigjt,device=device) + offset_h )*steps_h
    center_w = (torch.arange(in_width,device=device) + offset_w)*steps_w
    #遍历每一个像素点
    shift_y,shift_x = torch.meshgrid(center_h,center_w,indexing='ij')
    #“ij”产生shape为yx的矩阵，先是dim=0全部以y填充，然后填充x次，然后是dim=1
    #上以x填充填充y次
    print("shift_y",shift_y.shape)
    print("shift_x",shift_x.shape)
    shift_y = shift_y.reshape(-1)
    shift_x = shift_x.reshape(-1)
    #拉直
    print(shift_x.shape)

    w = torch.cat((size_tensor * torch.sqrt(ratios_tensor[0]),
                   sizes[0]*torch.sqrt(ratios_tensor[1:])))*in_heigjt /in_width
    #将所有可能的计算情况形成一个矩阵 由于在归于系里面设置比例
    h = torch.cat(
        (size_tensor / torch.sqrt(ratios_tensor[0]),
        size_tensor[0] / torch.sqrt(ratios_tensor[1:]))
    )
    print(w.shape)
    anchor_manipu = torch.stack(
        (-w,-h,w,h)).T.repeat(in_heigjt * in_width,1)/2
    
    #dim=0 stack形成shape为4,5,然后转置为5,4,然后复制m次在dim=0上
    #这个时候 dim=0 的维度为5 * m 其中5说明所有情况为了符合每一个像素点都可以
    #经历5次变化 所以复制m次
    print(anchor_manipu.shape)
    out_grid = torch.stack(
        [shift_x,shift_y,shift_x,shift_y],
        dim=1
    ).repeat_interleave(boxes_per_pixel,dim=0)
    #在dim=1上堆积 shape为m ， 4,然后在dim=0按元素复制一定次数
    #为的是和anchor 对齐  由于按照元素相对于在dim=0上每连续5个是相同的
    #这样刚好和上面对应 实现一个x可以产生5个情况
    #一共有m个x ，则生产m * 5个数据 在dim=0
    output = out_grid + anchor_manipu 
    return output.unsqueeze(0)

In [10]:

def get_blk(i):
    if i==0:
        blk = base_net()
    elif i ==1:
        blk = down_sample_blk(64,128)
    elif i==4:
        blk = nn.AdaptiveMaxPool2d((1,1))
    else:
        blk = down_sample_blk(128,128)
    return blk

def blk_forward(x,blk,size,ratio,cls_predictor,bbox_predictor):
    Y = blk(x)
    anchors = multibox_prioe(Y,sizes=size,ratios=ratio)

    cls_preds = cls_predictor(Y)
    bbox_preds = bbox_predictor(Y)

    return (Y,anchors,cls_preds,bbox_preds)


In [11]:
sizes = [[0.2,0.272],
         [0.37,0.447],
         [0.54,0.619],
         [0.71,0.79],
         [0.88,0.961]]
ratios = [[1,2,0.5]]*5

num_anchors = len(sizes[0]) + len(ratios[0]) -1 


In [14]:
class TinySSD(nn.Module):
    def __init__(self,num_classes, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.num_classes = num_classes

        idx_to_in_channels = [64,128,128,128,128]
        for i in range(5):
            setattr(self,f'blk_{i}',get_blk(i))
            setattr(
                self,
                f'cls_{i}',
                cls_predictor(idx_to_in_channels[i],num_anchors,num_classes)
            )
            setattr(
                self,
                f'bbox_{i}',
                bbox_predictor(idx_to_in_channels[i],num_anchors)
            )
    def forward(self,x):
        anchors , cls_preds,bbox_preds = [None]*5,[None]*5,[None]*5
        for i in range(5):
            x,anchors[i],cls_preds[i],bbox_preds[i] = blk_forward(
                x,
                getattr(self,f'blk_{i}'),
                sizes[i],
                ratios[i],
                getattr(self,f'cls_{i}'),
                getattr(self,f'bbox_{i}')
            )
        anchors = torch.cat(anchors,dim=1)
        cls_preds = concat_preds(cls_preds)
        cls_preds = cls_preds.reshape(cls_preds.shape[0],-1,self.num_classes + 1)
        
        bbox_preds = concat_preds(bbox_preds)

        return anchors,cls_preds,bbox_preds



In [13]:
net = TinySSD(num_classes=1)
X = torch.zeros((32,3,256,256))
anchors , cls_preds , bbox_preds = net(X)
print('output anchors',anchors.shape)
print('output cls_preds',cls_preds.shape)
print('output bbox_preds',bbox_preds.shape)

shift_y torch.Size([32, 32])
shift_x torch.Size([32, 32])
torch.Size([1024])
torch.Size([4])
torch.Size([4096, 4])
shift_y torch.Size([16, 16])
shift_x torch.Size([16, 16])
torch.Size([256])
torch.Size([4])
torch.Size([1024, 4])
shift_y torch.Size([8, 8])
shift_x torch.Size([8, 8])
torch.Size([64])
torch.Size([4])
torch.Size([256, 4])
shift_y torch.Size([4, 4])
shift_x torch.Size([4, 4])
torch.Size([16])
torch.Size([4])
torch.Size([64, 4])
shift_y torch.Size([1, 1])
shift_x torch.Size([1, 1])
torch.Size([1])
torch.Size([4])
torch.Size([4, 4])
output anchors torch.Size([1, 5444, 4])
output cls_preds torch.Size([32, 5444, 2])
output bbox_preds torch.Size([32, 21776])
