# BiFPN (bi-directional feature pyramid network)

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [2]:
# Random stand-ins for three backbone feature maps. Going from c0 to c2 the
# spatial size halves (64 -> 32 -> 16) while the channel count doubles
# (2 -> 4 -> 8); the batch size is 1 throughout.
c0 = torch.randn(1, 2, 64, 64)
c1 = torch.randn(1, 4, 32, 32)
c2 = torch.randn(1, 8, 16, 16)

# Pyramid levels ordered from highest to lowest resolution.
# NOTE: this name shadows the builtin `input`; it is kept because the next
# cell indexes into it.
input = [c0, c1, c2]

# Channel count every level is projected to inside the BiFPN cell.
W_bifpn = 3

# Small constant added to the sum of fusion weights so the normalising
# division never hits zero.
epsilon = 1e-4

In [3]:
# One BiFPN layer written out step by step on three pyramid levels.
# Reads `input`, `W_bifpn` and `epsilon` from the previous cell.
#
# Every fusion node follows the same recipe:
#   fused = sum_i(w_i * feature_i) / (sum_i(w_i) + epsilon)
#   node  = BatchNorm(ReLU(Conv3x3(fused)))
# The two helpers below capture that recipe once, so each node's numerator
# and denominator are guaranteed to use the same weight tensor.


def fast_normalized_fusion(weights, features, eps=epsilon):
    """Blend feature maps with fast normalized fusion.

    Computes ``sum_i(weights[i] * features[i]) / (weights.sum() + eps)``.

    Args:
        weights: 1-D tensor with one scalar weight per entry of ``features``.
        features: sequence of tensors to blend; they must be broadcastable
            against each other (in this cell they all share one shape).
        eps: constant added to the denominator to avoid division by zero.

    Returns:
        The weighted, normalised combination of ``features``.

    Raises:
        ValueError: if the number of weights and features differ.
    """
    if weights.numel() != len(features):
        raise ValueError(
            f"expected {len(features)} fusion weights, got {weights.numel()}"
        )
    weighted_sum = sum(weights[i] * feat for i, feat in enumerate(features))
    # Normalise by the weights of *this* node only.
    return weighted_sum / (weights.sum() + eps)


def conv_relu_bn(x):
    """Apply a fresh 3x3 conv -> ReLU -> BatchNorm stack with W_bifpn channels.

    New (randomly initialised) layers are created on every call, matching the
    throw-away modules used throughout this demo cell.
    """
    x = nn.Conv2d(W_bifpn, W_bifpn, kernel_size=3,
                  stride=1, bias=True, padding=1)(x)
    x = nn.ReLU()(x)
    return nn.BatchNorm2d(W_bifpn)(x)


# ---------------------------------------------------------------------------
# Top-down pathway (coarse -> fine)
# ---------------------------------------------------------------------------

# P2 intermediate: project the coarsest level (8 channels) to W_bifpn channels.
P2_td = nn.Conv2d(8, W_bifpn, kernel_size=3,
                  stride=1, bias=True, padding=1)(input[2])

# P2 intermediate_upsample: 16x16 -> 32x32 so it can be fused with level 1.
p2_upsample = nn.Upsample(scale_factor=2, mode='nearest')(P2_td)

# P1 intermediate_input: project the middle level (4 channels) to W_bifpn.
P1_td_inp = nn.Conv2d(4, W_bifpn, kernel_size=3,
                      stride=1, bias=True, padding=1)(input[1])

# P1 intermediate
# Attention weights: Fast normalized fusion. torch.rand draws from [0, 1),
# so the weights start out non-negative.
p1_td_w = torch.rand(2, dtype=torch.float, requires_grad=True)
P1_td = conv_relu_bn(fast_normalized_fusion(p1_td_w, [P1_td_inp, p2_upsample]))

# P1 intermediate_upsample: 32x32 -> 64x64 so it can be fused with level 0.
p1_upsample = nn.Upsample(scale_factor=2, mode='nearest')(P1_td)

# P0 intermediate: project the finest level (2 channels) to W_bifpn channels.
P0_td = nn.Conv2d(2, W_bifpn, kernel_size=3,
                  stride=1, bias=True, padding=1)(input[0])

# ---------------------------------------------------------------------------
# Bottom-up pathway (fine -> coarse)
# ---------------------------------------------------------------------------

# P0 output
# Attention weights: Fast normalized fusion
p0_out_w = torch.rand(2, dtype=torch.float, requires_grad=True)
P0_out = conv_relu_bn(fast_normalized_fusion(p0_out_w, [P0_td, p1_upsample]))

# P0 output_downsample: 64x64 -> 32x32
p0_downsample = nn.MaxPool2d(kernel_size=2)(P0_out)

# P1 output: fuses the level-1 input, the top-down intermediate and the
# downsampled level-0 output (three inputs, hence three weights).
# Attention weights: Fast normalized fusion
p1_out_w = torch.rand(3, dtype=torch.float, requires_grad=True)
P1_out = conv_relu_bn(
    fast_normalized_fusion(p1_out_w, [P1_td_inp, P1_td, p0_downsample])
)

# P1 output_downsample: 32x32 -> 16x16
P1_out_downsample = nn.MaxPool2d(kernel_size=2)(P1_out)

# P2 output
# Attention weights: Fast normalized fusion. The denominator is built from
# p2_out_w, the same weights used in the numerator.
p2_out_w = torch.rand(2, dtype=torch.float, requires_grad=True)
P2_out = conv_relu_bn(
    fast_normalized_fusion(p2_out_w, [P2_td, P1_out_downsample])
)

# Final output: one W_bifpn-channel map per level, at the input resolutions.
output = [P0_out, P1_out, P2_out]
output[0].shape, output[1].shape, output[2].shape

(torch.Size([1, 3, 64, 64]),
 torch.Size([1, 3, 32, 32]),
 torch.Size([1, 3, 16, 16]))