diff --git a/steganogan/encoders.py b/steganogan/encoders.py index d1a93c8..225cb97 100644 --- a/steganogan/encoders.py +++ b/steganogan/encoders.py @@ -2,7 +2,14 @@ import torch from torch import nn +import torch.onnx +from torchvision.ops.deform_conv import DeformConv2d +input = torch.rand(4, 3, 10, 10) +kh, kw = 3, 3 +weight = torch.rand(5, 3, kh, kw) +offset = torch.rand(4, 2 * kh * kw, 8, 8) +mask = torch.rand(4, kh * kw, 8, 8) class BasicEncoder(nn.Module): """ @@ -16,11 +23,11 @@ class BasicEncoder(nn.Module): add_image = False def _conv2d(self, in_channels, out_channels): - return nn.Conv2d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=3, - padding=1 + return DeformConv2d( + input=input, + offset, + weight, + mask ) def _build_models(self):