# Implement IWRU-net for Watermark Removal

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

## Define the architecture

In [2]:
class UNet(nn.Module):
    KERNEL_SIZE = 3

    def __init__(self, unet_channels_in, unet_channels_out):
        super(UNet, self).__init__()

        # encoder
        self.conv_relu_1 = self.conv_relu(unet_channels_in, 48)
        self.conv_relu_pooling_2 = self.conv_relu_pooling(48, 48)
        self.conv_relu_pooling_3 = self.conv_relu_pooling(48, 48)
        self.conv_relu_pooling_4 = self.conv_relu_pooling(48, 48)
        self.conv_relu_pooling_5 = self.conv_relu_pooling(48, 48)
        self.conv_relu_pooling_6 = self.conv_relu_pooling(48, 48)

        # bottleneck
        self.conv_relu_pooling_7 = self.conv_relu_pooling(48, 48)

        # decoder
        self.conv_relu_convt_8 = self.conv_relu_convt(48, 48)
        self.conv_relu_9 = self.conv_relu(96, 96)
        self.conv_relu_convt10 = self.conv_relu_convt(96, 96)
        self.conv_relu11 = self.conv_relu(96, 96)  # the dimensions are not mentioned in the paper
        self.conv_relu_convt12 = self.conv_relu_convt(96, 96)
        self.conv_relu13 = self.conv_relu(144, 96)
        self.conv_relu_convt14 = self.conv_relu_convt(96, 96)
        self.conv_relu15 = self.conv_relu(144, 96)
        self.conv_relu_convt16 = self.conv_relu_convt(96, 96)
        self.conv_relu17 = self.conv_relu(144, 96)
        self.conv_relu_convt18 = self.conv_relu_convt(96, 96)
        self.conv_relu19 = self.conv_relu(96, 64)
        self.conv_relu20 = self.conv_relu(64, 32)
        self.conv_leaky_relu21 = self.conv_leaky_relu(32, unet_channels_out)

    def conv_relu(self, channels_in, channels_out):
        return nn.Sequential(
            nn.Conv2d(
                channels_in,
                channels_out,
                kernel_size=UNet.KERNEL_SIZE,
            ),
            nn.ReLU(inplace=True),
        )

    def conv_leaky_relu(self, channels_in, channels_out):
        return nn.Sequential(
            nn.Conv2d(
                channels_in,
                channels_out,
                kernel_size=UNet.KERNEL_SIZE,
            ),
            nn.LeakyReLU(inplace=True),
        )

    def conv_relu_pooling(self, channels_in, channels_out):
        return nn.Sequential(
            nn.Conv2d(
                channels_in,
                channels_out,
                kernel_size=UNet.KERNEL_SIZE,
            ),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(UNet.KERNEL_SIZE),
        )

    def conv_relu_convt(self, channels_in, channels_out):
        return nn.Sequential(
            nn.Conv2d(
                channels_in,
                channels_out,
                kernel_size=UNet.KERNEL_SIZE,
            ),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(
                channels_in,
                channels_out,
                kernel_size=UNet.KERNEL_SIZE,
            ),
        )

    def forward(self, x):
        # encoder
        enc1 = self.conv_relu_1(x)
        enc2 = self.conv_relu_pooling_2(enc1)
        enc3 = self.conv_relu_pooling_3(enc2)
        enc4 = self.conv_relu_pooling_4(enc3)
        enc5 = self.conv_relu_pooling_5(enc4)
        enc6 = self.conv_relu_pooling_6(enc5)

        # bottleneck
        bottleneck7 = self.conv_relu_pooling_7(enc6)

        # decoder
        decoder8 = self.conv_relu_convt_8(
            torch.cat(
                [
                    bottleneck7,
                    F.interpolate(
                        enc6,
                        bottleneck7.size()[2:],
                        mode="bilinear",
                        align_corners=True,
                    ),
                ],
                dim=1,
            )
        )
        decoder9 = self.conv_relu_9(decoder8)
        decoder10 = self.conv_relu_convt10(
            torch.cat(
                [
                    decoder9,
                    F.interpolate(
                        enc5, decoder9.size()[2:], mode="bilinear", align_corners=True
                    ),
                ],
                dim=1,
            )
        )
        decoder11 = self.conv_relu11(decoder10)
        decoder12 = self.self.conv_relu_convt12(
            torch.cat(
                [
                    decoder11,
                    F.interpolate(
                        enc4, decoder11.size()[2:], mode="bilinear", align_corners=True
                    ),
                ],
                dim=1,
            )
        )
        decoder13 = self.conv_relu13(decoder12)
        decoder14 = self.conv_relu_convt14(
            torch.cat(
                [
                    decoder13,
                    F.interpolate(
                        enc3, decoder13.size()[2:], mode="bilinear", align_corners=True
                    ),
                ],
                dim=1,
            )
        )
        decoder15 = self.self.conv_relu15(decoder14)
        decoder16 = self.conv_relu_convt16(
            torch.cat(
                [
                    decoder15,
                    F.interpolate(
                        enc2, decoder15.size()[2:], mode="bilinear", align_corners=True
                    ),
                ],
                dim=1,
            )
        )
        decoder17 = self.conv_relu17(decoder16)
        decoder18 = self.conv_relu_convt18(
            torch.cat(
                [
                    decoder17,
                    F.interpolate(
                        enc1, decoder17.size()[2:], mode="bilinear", align_corners=True
                    ),
                ],
                dim=1,
            )
        )
        decoder19 = self.conv_relu19(decoder18)
        decoder20 = self.conv_relu20(decoder19)

        return self.conv_leaky_relu21(decoder20)


In [3]:
unet = UNet(3, 3)
total_params = sum(layer.numel() for layer in unet.parameters())
print(f"The model has {total_params} parameters")

The model has 1612323 parameters
