# nets_BASE.py
import torch
import torch.nn.functional as F
# from torchsummary import summary
from torch import nn
from modules import DilatedResidualBlock, NLB, MTUR, DepthWiseDilatedResidualBlock


class basic(nn.Module):
    def __init__(self, num_features=64):
        super(basic, self).__init__()
        # ImageNet channel statistics, kept as frozen parameters so they move
        # with the module across devices but are never updated by the optimizer.
        self.mean = nn.Parameter(torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1),
                                 requires_grad=False)
        self.std = nn.Parameter(torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1),
                                requires_grad=False)
        # Downsample by 2 and project up to the working feature width.
        self.head = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(32, num_features, kernel_size=1, stride=1, padding=0), nn.ReLU()
        )
        # Dilated residual blocks with dilations 1-1-2-2-4-8-4-2-2-1-1: the
        # receptive field grows toward the middle of the stack and shrinks
        # again, all at constant spatial resolution.
        self.body = nn.Sequential(
            DilatedResidualBlock(num_features, 1),
            DilatedResidualBlock(num_features, 1),
            DilatedResidualBlock(num_features, 2),
            DilatedResidualBlock(num_features, 2),
            DilatedResidualBlock(num_features, 4),
            DilatedResidualBlock(num_features, 8),
            DilatedResidualBlock(num_features, 4),
            DilatedResidualBlock(num_features, 2),
            DilatedResidualBlock(num_features, 2),
            DilatedResidualBlock(num_features, 1),
            DilatedResidualBlock(num_features, 1)
        )
        # Note: `tail` is defined but never used in forward(); the output path
        # below goes through `f_process` and `output` instead.
        self.tail = nn.Sequential(
            # nn.Conv2d(num_features, num_features, kernel_size=3, padding=1), nn.ReLU(),
            nn.ConvTranspose2d(num_features, 32, kernel_size=4, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(32, 3, kernel_size=3, padding=1)
        )
        # Upsample back to the input resolution and widen the features.
        self.f_process = nn.Sequential(
            nn.ConvTranspose2d(num_features, 64, kernel_size=4, stride=2, padding=1),
            nn.GroupNorm(num_groups=64, num_channels=64),
            nn.SELU(inplace=True),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )
        # Project the processed features back down to a 3-channel image.
        self.output = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, padding=1), nn.ReLU(),
            # nn.ConvTranspose2d(num_features, 32, kernel_size=4, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(64, 3, kernel_size=3, padding=1)
        )
        # Make every ReLU in the network operate in place to save memory.
        for m in self.modules():
            if isinstance(m, nn.ReLU):
                m.inplace = True

    def forward(self, x):
        x = (x - self.mean) / self.std  # normalize with ImageNet statistics
        f = self.head(x)                # downsample and embed
        f = self.body(f)                # dilated residual refinement
        f = self.f_process(f)           # upsample back to input resolution
        x = self.output(f)              # project features to a 3-channel image
        x = (x * self.std + self.mean).clamp(min=0, max=1)  # denormalize to [0, 1]
        return x
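

# Minimal smoke test (an illustrative sketch, not part of the original file):
# builds the network and pushes a dummy batch through it. Assumes `modules.py`
# from this repo is importable; spatial dims should be even so the stride-2
# head and the stride-2 ConvTranspose2d in `f_process` restore the resolution.
if __name__ == "__main__":
    net = basic(num_features=64).eval()
    dummy = torch.rand(1, 3, 256, 256)  # one 256x256 RGB image in [0, 1]
    with torch.no_grad():
        out = net(dummy)
    print(out.shape)  # expected: torch.Size([1, 3, 256, 256])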