-
Notifications
You must be signed in to change notification settings - Fork 121
/
base.py
136 lines (110 loc) · 4.65 KB
/
base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import torch
import torch.nn as nn
from efficientnet_pytorch import EfficientNet
from torchvision.models.resnet import resnet18
class Up(nn.Module):
    """Upsample a coarse feature map, fuse it with a skip map, and refine.

    ``x1`` is bilinearly upsampled, concatenated with ``x2`` along the
    channel axis (``x2`` channels first), and passed through two
    3x3 conv -> BatchNorm -> ReLU stages.

    Args:
        in_channels: Channel count of the concatenated tensor, i.e. the
            channels of ``x2`` plus the channels of ``x1``.
        out_channels: Channel count produced by both refinement convolutions.
        scale_factor: Spatial factor by which ``x1`` is upsampled.
    """

    def __init__(self, in_channels, out_channels, scale_factor=2):
        super().__init__()

        self.up = nn.Upsample(
            scale_factor=scale_factor,
            mode='bilinear',
            align_corners=True,
        )

        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x1, x2):
        """Fuse the coarse map ``x1`` with the skip map ``x2``.

        Args:
            x1: Coarse feature map to be upsampled.
            x2: Skip feature map; its spatial size defines the output size.

        Returns:
            Tensor with ``out_channels`` channels and the spatial size of ``x2``.
        """
        upsampled = self.up(x1)

        # A fixed scale factor only lands on x2's size when the input extent
        # divides evenly through the strided stages. Otherwise resample x1
        # straight to x2's size (one interpolation, same mode) so that the
        # concatenation below does not fail on a spatial mismatch.
        if upsampled.shape[2:] != x2.shape[2:]:
            upsampled = nn.functional.interpolate(
                x1,
                size=x2.shape[2:],
                mode='bilinear',
                align_corners=True,
            )

        fused = torch.cat([x2, upsampled], dim=1)
        return self.conv(fused)
class CamEncode(nn.Module):
    """Image encoder built on a pretrained ``efficientnet-b0`` trunk.

    The trunk is run block by block, the feature map preceding each drop in
    height is recorded, and the fourth and fifth recorded maps are fused by
    an ``Up`` block into a ``C``-channel output.

    Args:
        C: Number of output channels of the fused feature map.
    """

    def __init__(self, C):
        super().__init__()
        self.C = C

        self.trunk = EfficientNet.from_pretrained("efficientnet-b0")
        # 320 + 112: presumably the channel counts of the 'reduction_5' and
        # 'reduction_4' maps of efficientnet-b0 -- TODO confirm.
        self.up1 = Up(320 + 112, self.C)

    def get_eff_depth(self, x):
        """Run the trunk and return the fused 'reduction_5'/'reduction_4' map.

        Args:
            x: Input tensor fed to the trunk's stem convolution.

        Returns:
            Output of ``self.up1`` applied to the 'reduction_5' map (upsampled
            input) and the 'reduction_4' map (skip input).
        """
        # adapted from https://github.com/lukemelas/EfficientNet-PyTorch/blob/master/efficientnet_pytorch/model.py#L231
        trunk = self.trunk
        blocks = trunk._blocks
        num_blocks = len(blocks)
        base_rate = trunk._global_params.drop_connect_rate

        feats = {}

        # Stem
        x = trunk._conv_stem(x)
        x = trunk._bn0(x)
        x = trunk._swish(x)
        last = x

        # Blocks
        for idx, block in enumerate(blocks):
            rate = base_rate
            if rate:
                # Scale the drop-connect rate with the block's depth.
                rate *= float(idx) / num_blocks
            x = block(x, drop_connect_rate=rate)

            # The height just shrank, so ``last`` was the final map at the
            # previous resolution: record it as the next endpoint.
            if x.size(2) < last.size(2):
                feats['reduction_{}'.format(len(feats) + 1)] = last
            last = x

        # Head: the final block output closes the endpoint list.
        feats['reduction_{}'.format(len(feats) + 1)] = x

        return self.up1(feats['reduction_5'], feats['reduction_4'])

    def forward(self, x):
        """Return ``get_eff_depth(x)``."""
        return self.get_eff_depth(x)
class BevEncode(nn.Module):
    """ResNet-18-based encoder/decoder with up to three parallel output heads.

    The input goes through a custom stem convolution and the first three
    residual stages of an un-pretrained ``resnet18`` (its max-pool and fourth
    stage are not used). Each head fuses the ``layer3`` output with the
    ``layer1`` output through its own ``Up`` block (x4 upsampling), followed
    by a further x2 upsampling and a projection to the head's channel count.

    Args:
        inC: Number of input channels.
        outC: Number of channels of the main head output.
        instance_seg: Whether to build and run the embedding head.
        embedded_dim: Number of channels of the embedding head output.
        direction_pred: Whether to build and run the direction head.
        direction_dim: Number of channels of the direction head output.
    """

    def __init__(self, inC, outC, instance_seg=True, embedded_dim=16, direction_pred=True, direction_dim=37):
        super().__init__()

        trunk = resnet18(pretrained=False, zero_init_residual=True)

        # A fresh stem convolution is used instead of trunk.conv1 so that the
        # network accepts ``inC`` input channels.
        self.conv1 = nn.Conv2d(inC, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = trunk.bn1
        self.relu = trunk.relu

        self.layer1 = trunk.layer1
        self.layer2 = trunk.layer2
        self.layer3 = trunk.layer3

        # Main head.
        self.up1 = Up(64 + 256, 256, scale_factor=4)
        self.up2 = self._make_head(outC)

        # Optional embedding head.
        self.instance_seg = instance_seg
        if instance_seg:
            self.up1_embedded = Up(64 + 256, 256, scale_factor=4)
            self.up2_embedded = self._make_head(embedded_dim)

        # Optional direction head.
        self.direction_pred = direction_pred
        if direction_pred:
            self.up1_direction = Up(64 + 256, 256, scale_factor=4)
            self.up2_direction = self._make_head(direction_dim)

    @staticmethod
    def _make_head(out_channels):
        """Build the x2-upsample + 3x3 conv + 1x1 projection shared by all heads."""
        return nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear',
                        align_corners=True),
            nn.Conv2d(256, 128, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, out_channels, kernel_size=1, padding=0),
        )

    def forward(self, x):
        """Encode ``x`` and decode every enabled head.

        Args:
            x: Input tensor with ``inC`` channels.

        Returns:
            Tuple ``(x, x_embedded, x_direction)``: the main head output, the
            embedding head output (``None`` when ``instance_seg`` is false),
            and the direction head output (``None`` when ``direction_pred``
            is false).
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x1 = self.layer1(x)   # skip features for every head
        x = self.layer2(x1)
        x2 = self.layer3(x)   # deepest features for every head

        x = self.up1(x2, x1)
        x = self.up2(x)

        if self.instance_seg:
            x_embedded = self.up1_embedded(x2, x1)
            x_embedded = self.up2_embedded(x_embedded)
        else:
            x_embedded = None

        if self.direction_pred:
            # The direction head must use its own fusion block. Routing it
            # through up1_embedded would leave up1_direction unused, couple
            # the two heads, and fail when instance_seg is disabled.
            # NOTE(review): checkpoints trained with the direction head routed
            # through up1_embedded may need fine-tuning -- confirm.
            x_direction = self.up1_direction(x2, x1)
            x_direction = self.up2_direction(x_direction)
        else:
            x_direction = None

        return x, x_embedded, x_direction