In [1]:
import torch
from torch import nn

In [15]:
class UNet(nn.Module):
    """U-Net style encoder/decoder producing one output map per class.

    The encoder runs five double-convolution stages separated by 2x2
    max-pooling.  The decoder mirrors it with four transposed-convolution
    upsamplings; after each one the result is concatenated (channel-wise)
    with the encoder feature map of the same level and passed through another
    double convolution.  A final 1x1 convolution maps the features to
    ``num_classes`` channels; no activation is applied to the result.

    Args:
        in_ch: Number of channels of the input tensor.
        num_classes: Number of channels of the output tensor.
        base_ch: Channel width of the first level; every deeper level doubles
            it.  The default of 64 yields widths 64/128/256/512/1024.

    Shape:
        Input ``(N, in_ch, H, W)`` -> output ``(N, num_classes, H, W)``.
        H and W need not be multiples of 16: when pooling rounds a size down,
        the upsampled map is zero-padded to the skip connection's size before
        concatenation.  H and W must still be at least 16 so that the four
        poolings leave a non-empty feature map.
    """

    def __init__(self, in_ch, num_classes, base_ch=64):
        super().__init__()

        # Channel widths of the five resolution levels: c, 2c, 4c, 8c, 16c.
        c1, c2, c3, c4, c5 = (base_ch * 2 ** level for level in range(5))

        # Encoder (contracting path).
        self.dconv_down1 = self._double_conv(in_ch, c1)
        self.dconv_down2 = self._double_conv(c1, c2)
        self.dconv_down3 = self._double_conv(c2, c3)
        self.dconv_down4 = self._double_conv(c3, c4)
        self.dconv_down5 = self._double_conv(c4, c5)

        # Stateless, so a single instance is shared by all encoder levels.
        self.maxpool = nn.MaxPool2d(2)

        # Decoder upsampling: doubles H and W while halving the channels.
        self.upconv4 = self._up_conv(c5, c4)
        self.upconv3 = self._up_conv(c4, c3)
        self.upconv2 = self._up_conv(c3, c2)
        self.upconv1 = self._up_conv(c2, c1)

        # Decoder convolutions: the input width is doubled again by the
        # concatenated skip connection (e.g. c4 upsampled + c4 skip = c5).
        self.dconv_up4 = self._double_conv(c5, c4)
        self.dconv_up3 = self._double_conv(c4, c3)
        self.dconv_up2 = self._double_conv(c3, c2)
        self.dconv_up1 = self._double_conv(c2, c1)

        # 1x1 convolution projecting features to per-class maps.
        self.conv_last = nn.Conv2d(c1, num_classes, 1)

    def _double_conv(self, in_ch, out_ch):
        """Return two 3x3 convolutions (padding 1), each followed by ReLU."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.ReLU()
        )

    def _up_conv(self, in_ch, out_ch):
        """Return a 2x2, stride-2 transposed convolution (doubles H and W)."""
        return nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2)

    @staticmethod
    def _concat_skip(up, skip):
        """Concatenate ``up`` and ``skip`` along the channel dimension.

        ``up`` is zero-padded first if it is spatially smaller than ``skip``,
        which happens when max-pooling rounded an odd size down on the way
        into the encoder.  When the sizes already match, no padding is done.
        """
        pad_h = skip.size(2) - up.size(2)
        pad_w = skip.size(3) - up.size(3)
        if pad_h or pad_w:
            # F.pad takes (left, right, top, bottom) for the last two dims.
            up = nn.functional.pad(
                up,
                [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2],
            )
        return torch.cat([up, skip], dim=1)

    def forward(self, X):
        """Map ``X`` of shape (N, in_ch, H, W) to (N, num_classes, H, W)."""
        # Encoder: keep each level's pre-pooling features for the skip paths.
        conv1 = self.dconv_down1(X)
        X = self.maxpool(conv1)

        conv2 = self.dconv_down2(X)
        X = self.maxpool(conv2)

        conv3 = self.dconv_down3(X)
        X = self.maxpool(conv3)

        conv4 = self.dconv_down4(X)
        X = self.maxpool(conv4)

        # Bottleneck.
        X = self.dconv_down5(X)

        # Decoder: upsample, merge with the matching encoder level, convolve.
        X = self.upconv4(X)
        X = self.dconv_up4(self._concat_skip(X, conv4))
        X = self.upconv3(X)
        X = self.dconv_up3(self._concat_skip(X, conv3))
        X = self.upconv2(X)
        X = self.dconv_up2(self._concat_skip(X, conv2))
        X = self.upconv1(X)
        X = self.dconv_up1(self._concat_skip(X, conv1))

        out = self.conv_last(X)
        return out


In [19]:
# Use spatial sizes that are multiples of 16 (e.g. 256, 512): they pass
# through the four 2x2 max-pools and upsamplings with skip-connection shapes
# that line up exactly.  A power of two >= 16 is one such case, not a requirement.
torch.manual_seed(0)  # fixed seed: reproducible random input and initial weights
X = torch.randn(1, 3, 256, 256)
model = UNet(3, 10)
outputs = model(X)

In [20]:
# Shape check: batch and spatial size match the input; channels = num_classes (10).
outputs.shape

torch.Size([1, 10, 256, 256])

In [21]:
# Inspect the raw output values (PyTorch abbreviates large tensors with "...").
outputs

tensor([[[[ 0.0773,  0.0790,  0.0799,  ...,  0.0811,  0.0785,  0.0803],
          [ 0.0732,  0.0744,  0.0804,  ...,  0.0803,  0.0779,  0.0819],
          [ 0.0751,  0.0737,  0.0746,  ...,  0.0736,  0.0807,  0.0789],
          ...,
          [ 0.0766,  0.0795,  0.0815,  ...,  0.0662,  0.0783,  0.0834],
          [ 0.0770,  0.0846,  0.0820,  ...,  0.0861,  0.0812,  0.0821],
          [ 0.0844,  0.0868,  0.0816,  ...,  0.0880,  0.0872,  0.0865]],

         [[-0.0628, -0.0658, -0.0643,  ..., -0.0659, -0.0593, -0.0611],
          [-0.0625, -0.0747, -0.0773,  ..., -0.0799, -0.0703, -0.0683],
          [-0.0583, -0.0780, -0.0764,  ..., -0.0593, -0.0719, -0.0632],
          ...,
          [-0.0613, -0.0844, -0.0749,  ..., -0.0717, -0.0757, -0.0611],
          [-0.0691, -0.0788, -0.0640,  ..., -0.0750, -0.0743, -0.0717],
          [-0.0645, -0.0689, -0.0617,  ..., -0.0703, -0.0640, -0.0594]],

         [[ 0.1088,  0.1108,  0.0977,  ...,  0.1009,  0.1009,  0.1049],
          [ 0.0982,  0.0991,  