# In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DoubleConv(nn.Module):
    """Two stacked (3x3 conv -> BatchNorm -> ReLU) stages.

    Both convolutions use a 3x3 kernel with padding=1 (stride 1), so the
    spatial size is preserved; only the channel count changes from ``in_ch``
    to ``out_ch``. Conv biases are disabled because each conv is immediately
    followed by BatchNorm, whose learned shift makes a bias redundant.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = []
        # First stage maps in_ch -> out_ch, second stage out_ch -> out_ch.
        for src_ch in (in_ch, out_ch):
            layers += [
                nn.Conv2d(src_ch, out_ch, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply both conv stages to ``x`` and return the result."""
        return self.net(x)

class UNet(nn.Module):
    """Four-level U-Net for dense (per-pixel) prediction.

    Encoder: ``DoubleConv`` blocks separated by 2x2 max-pooling, widening
    from ``base_ch`` to ``base_ch * 8`` channels, followed by a
    ``base_ch * 16`` bottleneck. Decoder: stride-2 transposed convolutions
    that halve the channel count. Each is followed by concatenation with the
    matching encoder feature map (skip connection) and a ``DoubleConv``.
    A final 1x1 convolution maps ``base_ch`` features to ``out_channels``.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the final 1x1 convolution.
        base_ch: width of the first encoder level.

    Input is a 4-D ``(N, in_channels, H, W)`` tensor and the output is
    ``(N, out_channels, H, W)``. H and W need not be multiples of 16: when
    max-pooling floors an odd size, the upsampled decoder map is zero-padded
    to the skip connection's size before concatenation. Each spatial
    dimension must still be at least 16 so the bottleneck is non-empty.
    """

    def __init__(self, in_channels=9, out_channels=1, base_ch=32):
        super().__init__()
        self.enc1 = DoubleConv(in_channels, base_ch)
        self.pool = nn.MaxPool2d(2)
        self.enc2 = DoubleConv(base_ch, base_ch*2)
        self.enc3 = DoubleConv(base_ch*2, base_ch*4)
        self.enc4 = DoubleConv(base_ch*4, base_ch*8)

        self.bottleneck = DoubleConv(base_ch*8, base_ch*16)

        # Each dec* receives [upsampled, skip] concatenated, hence 2x input channels.
        self.up4 = nn.ConvTranspose2d(base_ch*16, base_ch*8, kernel_size=2, stride=2)
        self.dec4 = DoubleConv(base_ch*16, base_ch*8)
        self.up3 = nn.ConvTranspose2d(base_ch*8, base_ch*4, kernel_size=2, stride=2)
        self.dec3 = DoubleConv(base_ch*8, base_ch*4)
        self.up2 = nn.ConvTranspose2d(base_ch*4, base_ch*2, kernel_size=2, stride=2)
        self.dec2 = DoubleConv(base_ch*4, base_ch*2)
        self.up1 = nn.ConvTranspose2d(base_ch*2, base_ch, kernel_size=2, stride=2)
        self.dec1 = DoubleConv(base_ch*2, base_ch)

        self.final_conv = nn.Conv2d(base_ch, out_channels, kernel_size=1)

    def forward(self, x):
        """Run encoder, bottleneck and decoder on ``x``.

        Returns the raw (un-activated) output of the final 1x1 convolution.
        """
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))

        b = self.bottleneck(self.pool(e4))

        d4 = self._up_merge(self.up4, self.dec4, b, e4)
        d3 = self._up_merge(self.up3, self.dec3, d4, e3)
        d2 = self._up_merge(self.up2, self.dec2, d3, e2)
        d1 = self._up_merge(self.up1, self.dec1, d2, e1)

        return self.final_conv(d1)

    def _up_merge(self, up, dec, x, skip):
        """Upsample ``x``, align it to ``skip``, concatenate channels, refine with ``dec``."""
        x = self._match_size(up(x), skip)
        # Keep the [upsampled, skip] channel order so trained weights line up.
        return dec(torch.cat([x, skip], dim=1))

    @staticmethod
    def _match_size(x, ref):
        """Zero-pad ``x`` so its last two dims (H, W) equal those of ``ref``.

        After floor-pooling and a stride-2 transposed conv, the upsampled map
        is never larger than the skip map, so the differences are >= 0.
        """
        diff_h = ref.shape[-2] - x.shape[-2]
        diff_w = ref.shape[-1] - x.shape[-1]
        if diff_h == 0 and diff_w == 0:
            return x
        # F.pad orders padding from the last dim backwards:
        # (left, right, top, bottom). The split is centered, with any odd
        # remainder going to the right/bottom edge.
        return F.pad(
            x,
            [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2],
        )

if __name__ == "__main__":
    # Smoke test: run one random 9-channel 128x128 sample through the network,
    # then report the tensor shapes and the parameter counts.
    x = torch.randn(1, 9, 128, 128)
    model = UNet(in_channels=9, out_channels=1, base_ch=32)
    y = model(x)

    for line in (f"Input shape : {x.shape}", f"Output shape: {y.shape}"):
        print(line)

    params = list(model.parameters())
    total_params = sum(p.numel() for p in params)
    trainable_params = sum(p.numel() for p in params if p.requires_grad)
    print(f"Total parameters    : {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

class DoubleConv(nn.Module):
    """(3x3 conv -> BatchNorm -> ReLU) applied twice; H and W are unchanged.

    NOTE(review): ``DoubleConv`` is declared twice in this notebook with
    identical layers. This redefinition shadows the earlier one for any code
    that runs after it. Consider keeping a single definition.
    """

    @staticmethod
    def _stage(cin, cout):
        # bias=False: the BatchNorm that follows provides the per-channel shift.
        return [
            nn.Conv2d(cin, cout, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(cout),
            nn.ReLU(inplace=True),
        ]

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.net = nn.Sequential(
            *self._stage(in_ch, out_ch),
            *self._stage(out_ch, out_ch),
        )

    def forward(self, x):
        """Return the output of both conv stages applied to ``x``."""
        return self.net(x)

class UNet_MultiRes(nn.Module):
    """U-Net that fuses a coarse-resolution input with a fine-resolution input.

    ``coarse_x`` is bilinearly resized to the spatial size of ``fine_x``. It
    then passes through two 3x3 conv + ReLU layers, is concatenated with
    ``fine_x`` along the channel axis, and is fed to a ``UNet``.

    Args:
        coarse_ch: channels of ``coarse_x``.
        fine_ch: channels of ``fine_x``.
        out_ch: output channels of the wrapped ``UNet``.
        base_ch: ``base_ch`` forwarded to the wrapped ``UNet``.
        coarse_feat_ch: width of the coarse feature branch. The default of 32
            reproduces the original architecture.
    """

    def __init__(self, coarse_ch=7, fine_ch=1, out_ch=1, base_ch=32, coarse_feat_ch=32):
        super().__init__()

        # Attribute names ``coarse_up`` / ``unet`` are kept so existing
        # state_dict checkpoints still load.
        self.coarse_up = nn.Sequential(
            nn.Conv2d(coarse_ch, coarse_feat_ch, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(coarse_feat_ch, coarse_feat_ch, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

        self.unet = UNet(
            in_channels=coarse_feat_ch + fine_ch,
            out_channels=out_ch,
            base_ch=base_ch,
        )

    def forward(self, coarse_x, fine_x):
        """Fuse ``coarse_x`` and ``fine_x`` and return the wrapped UNet's output.

        Both inputs are 4-D ``(N, C, H, W)`` tensors (``F.interpolate`` in
        bilinear mode requires 4-D input) and must share the batch size N so
        they can be concatenated.
        """
        # Resample the coarse grid onto the fine grid before any convolution.
        upsampled = F.interpolate(
            coarse_x,
            size=fine_x.shape[-2:],
            mode="bilinear",
            align_corners=False,
        )
        coarse_feat = self.coarse_up(upsampled)

        fused = torch.cat([coarse_feat, fine_x], dim=1)

        return self.unet(fused)

if __name__ == "__main__":
    # Smoke test for the multi-resolution model: a random 7-channel 16x16
    # coarse grid plus a 1-channel 128x128 fine input; report shapes and size.
    coarse_x = torch.randn(1, 7, 16, 16)
    fine_x = torch.randn(1, 1, 128, 128)

    model = UNet_MultiRes(coarse_ch=7, fine_ch=1, out_ch=1, base_ch=32)
    y = model(coarse_x, fine_x)

    for line in (
        f"Coarse input : {coarse_x.shape}",
        f"Fine input   : {fine_x.shape}",
        f"Output shape : {y.shape}",
    ):
        print(line)

    total_params = sum(param.numel() for param in model.parameters())
    print(f"Total parameters: {total_params:,}")

# Input shape : torch.Size([1, 9, 128, 128])
# Output shape: torch.Size([1, 1, 128, 128])
# Total parameters    : 7,764,769
# Trainable parameters: 7,764,769
# Coarse input : torch.Size([1, 7, 16, 16])
# Fine input   : torch.Size([1, 1, 128, 128])
# Output shape : torch.Size([1, 1, 128, 128])
# Total parameters: 7,782,977
