In [1]:
import torch
from torch import nn
import torchvision

In [47]:
def apply_weight_init(param, weight_style: str, num_anchors: int, num_classes: int): 
    assert isinstance(param, torch.nn.modules.conv.Conv2d)
    PRIOR_PROB = torch.tensor(0.01)
    SIGMA = torch.tensor(0.01)

    if weight_style == "default":
        # Seee SSD paper
        nn.init.xavier_uniform_(param.weight)
        # the bias is zero I assume

    elif weight_style == "classification":
        # Initialize all classes to have bias of 0
        custom_bias = torch.zeros((num_classes, 1), dtype=torch.float32)
        """ Focal   Method: torch.log((num_classes - 1) * (1 - PRIOR_PROB) / (PRIOR_PROB))"""
        """ TDT4265 Method: torch.log(PRIOR_PROB * ((num_classes - 1) / (1 - PRIOR_PROB)))"""
        # Except for background class:
        custom_bias[0] = torch.log(PRIOR_PROB * ((num_classes - 1) / (1 - PRIOR_PROB)))   
        # Repeat foreach anchor
        custom_bias = torch.vstack([custom_bias]*num_anchors)
        # Update weights and biases
        torch.nn.init.normal_(param.weight, std=SIGMA.item())
        param.bias.data = custom_bias.squeeze()

    elif weight_style == "regression":
        # See Focal Loss paper
        torch.nn.init.normal_(param.weight, std=SIGMA.item())
        torch.nn.init.zeros_(param.bias)
    else:
        raise NotImplementedError(f"Unknown weight_style: {weight_style}")

In [48]:
conv = nn.Conv2d(256, 6*9, kernel_size=3, stride=1, padding=1)
num_classes = 8+1
num_anchors = 6

apply_weight_init(conv, "classification", num_anchors, num_classes)

Bias: torch.Size([54])
Bias: torch.Size([54, 1])
Bias: torch.Size([54])


In [2]:
from neuralvision.ssd.retinanet import create_subnet

In [3]:
from torchvision.models.detection import retinanet_resnet50_fpn

In [4]:
mod = retinanet_resnet50_fpn(num_classes=2, pretrained=False)

In [9]:
conv = nn.Conv2d(256, 6*9, kernel_size=3, stride=1, padding=1)

In [22]:
indexes = [numb for numb in range(6*9) if numb % 9 == 0]

In [23]:
indexes

[0, 9, 18, 27, 36, 45]

In [16]:
map = {
    0: 'background',
    1: 'car',
    2: 'truck',
    3: 'bus',
    4: 'motorcycle',
    5: 'bicycle',
    6: 'scooter',
    7: 'person',
    8: 'rider',
    9: "background",
    10: 'car',
    11: 'truck',
    12: 'bus',
    13: 'motorcycle',
    14: 'bicycle',
    15: 'scooter',
    16: 'person',
    17: 'rider',
    18: "background",
}


In [26]:
import math

In [32]:
background_pos = 0
n_positions = conv.bias.shape[0]
for K in range(n_positions):
    if K % 9 == 0:
        idx = K+1
        p = 0.99
        bias = math.log(p * ((idx-1) / (1-p)))
        print(f"K: {K}, bias: {bias}")


ValueError: math domain error

In [None]:
from torchvision.models.detection import retinanet_resnet50_fpn