In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Random toy inputs: `x` is a 1x1x28x28 tensor and `t` is a 1x10 tensor.
# `x` is drawn first, then `t`, so the order of RNG draws is fixed.
x, t = torch.randn(1, 1, 28, 28), torch.randn(1, 10)

In [4]:
# Turn the 2-D `t` into a 4-D map that matches x's spatial size: two trailing
# singleton axes are appended, then expanded (as a view, no copy) to x's last
# two dimensions.
#
# Guarded so that re-running this cell is a no-op: indexing an already
# expanded 4-D `t` again would add two more axes and make `expand` raise.
# The spatial size is read from `x` instead of being hard-coded, so the two
# tensors cannot drift apart.
if t.dim() == 2:
    t = t[:, :, None, None].expand(-1, -1, *x.shape[-2:])

In [5]:
# Element-wise sum; x's size-1 channel axis broadcasts against t's channel axis.
x + t

tensor([[[[-2.1204, -1.9200, -0.8335,  ..., -3.0485, -1.1710, -1.8532],
          [-4.4841, -0.3954, -4.5222,  ..., -1.3759, -0.2037, -1.6805],
          [-0.5427, -1.2588, -1.9890,  ..., -1.8176, -2.5087, -1.5693],
          ...,
          [-1.6859,  0.2088, -3.3992,  ..., -3.4385, -2.7556, -2.9025],
          [ 0.2439,  0.1677, -2.8887,  ..., -2.3116, -3.9316, -2.1073],
          [-2.3236, -1.0756, -1.2229,  ..., -3.7300, -0.6277, -0.7658]],

         [[-1.3499, -1.1495, -0.0629,  ..., -2.2780, -0.4005, -1.0827],
          [-3.7136,  0.3751, -3.7517,  ..., -0.6054,  0.5668, -0.9100],
          [ 0.2278, -0.4883, -1.2185,  ..., -1.0471, -1.7382, -0.7988],
          ...,
          [-0.9154,  0.9793, -2.6287,  ..., -2.6680, -1.9851, -2.1320],
          [ 1.0144,  0.9382, -2.1182,  ..., -1.5411, -3.1611, -1.3368],
          [-1.5531, -0.3051, -0.4524,  ..., -2.9594,  0.1428,  0.0047]],

         [[-1.2268, -1.0264,  0.0602,  ..., -2.1549, -0.2774, -0.9596],
          [-3.5905,  0.4982, -

In [None]:
class DoubleConv(nn.Module):
    """Two consecutive Conv(3x3, padding=1) -> BatchNorm -> ReLU stages.

    The first stage maps ``in_channels`` to ``out_channels``; the second keeps
    ``out_channels``. All six layers live in ``self.double_conv`` in the order
    conv, norm, relu, conv, norm, relu.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # Stage 1 consumes `in_channels`, stage 2 consumes stage 1's output.
        for stage_in in (in_channels, out_channels):
            stages += [
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both conv/norm/activation stages and return the result."""
        features = self.double_conv(x)
        return features
    
class Down(nn.Module):
    """Encoder step: 2x2 max-pooling followed by a ``DoubleConv``.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the ``DoubleConv``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool = nn.MaxPool2d(2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x):
        """Pool ``x`` with a 2x2 window, then apply the double convolution."""
        return self.conv(self.maxpool(x))

class Up(nn.Module):
    """Decoder step: transposed-conv upsampling, skip concatenation, ``DoubleConv``.

    The transposed convolution (kernel 2, stride 2) produces
    ``in_channels // 2`` channels. Its output is concatenated with the skip
    tensor along the channel axis and fed to ``DoubleConv(in_channels, ...)``,
    so the skip tensor must supply the remaining
    ``in_channels - in_channels // 2`` channels.

    Args:
        in_channels: channels of the tensor to upsample (and of the
            concatenated tensor fed to the double convolution).
        out_channels: channels produced by the double convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1``, fit it to ``x2``'s height/width, concatenate and convolve.

        Args:
            x1: tensor to upsample; indexed as a 4-D tensor (axes 2 and 3 are
                treated as height and width).
            x2: skip-connection tensor whose height/width is the target size.

        Returns:
            Output of the double convolution applied to ``cat([x2, upsampled])``.
        """
        upsampled = self.up(x1)

        # Size gap between the skip tensor and the upsampled tensor per axis.
        gap_h = x2.shape[2] - upsampled.shape[2]
        gap_w = x2.shape[3] - upsampled.shape[3]

        # Split each gap between the two sides; the trailing side gets the
        # remainder. F.pad takes the last axis first: [left, right, top, bottom].
        left, top = gap_w // 2, gap_h // 2
        upsampled = F.pad(upsampled, [left, gap_w - left, top, gap_h - top])

        # Skip features go first along the channel axis.
        return self.conv(torch.cat([x2, upsampled], dim=1))

class UNet(nn.Module):
    """U-Net with three resolution reductions and three matching upsampling steps.

    Encoder: ``down1`` (full resolution, 64 ch) -> ``down2`` (1/2, 128 ch) ->
    ``down3`` (1/4, 256 ch) -> pooled ``bottleneck`` (1/8, 512 ch).
    Decoder: ``up1``/``up2``/``up3`` each double the resolution and merge the
    skip tensor of the same scale (``x3``, ``x2``, ``x1``), followed by a 1x1
    convolution to ``out_channels``.

    The bottleneck is preceded by a parameter-free 2x2 max-pool so that every
    ``Up`` step receives a skip tensor twice the size of the tensor it
    upsamples. Without that pooling the bottleneck output has the same
    resolution as ``x3``; ``up1`` then upsamples to twice ``x3``'s size, the
    size differences in ``Up.forward`` become negative, and the upsampled
    features are cropped back instead of being aligned with the skip.
    Because the pooling layer has no parameters, ``state_dict`` keys are
    unchanged.

    Args:
        in_channels: number of channels of the input image.
        out_channels: number of channels of the output map.

    Note:
        The input is pooled by 2 three times, so height and width must be at
        least 8. Sizes that are not multiples of 8 are handled by the padding
        in ``Up.forward``.
    """

    def __init__(self, in_channels, out_channels):
        super(UNet, self).__init__()
        self.down1 = DoubleConv(in_channels, 64)
        self.down2 = Down(64, 128)
        self.down3 = Down(128, 256)
        # Third downsampling step; kept separate from `bottleneck` so the
        # bottleneck's parameter names stay `bottleneck.double_conv.*`.
        self.bottleneck_pool = nn.MaxPool2d(2)
        self.bottleneck = DoubleConv(256, 512)
        self.up1 = Up(512, 256)
        self.up2 = Up(256, 128)
        self.up3 = Up(128, 64)
        self.out = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        """Map ``x`` to an output with ``out_channels`` channels and x's height/width."""
        # Encoder: keep each scale's features for the skip connections.
        x1 = self.down1(x)
        x2 = self.down2(x1)
        x3 = self.down3(x2)

        # Bottleneck runs one scale below x3 so that up1 has a real 2x step.
        x = self.bottleneck(self.bottleneck_pool(x3))

        # Decoder: upsample and merge with the skip tensor of the same scale.
        x = self.up1(x, x3)
        x = self.up2(x, x2)
        x = self.up3(x, x1)

        x = self.out(x)
        return x


In [84]:
# Build a network that takes 1 input channel and produces 3 output channels.
unet = UNet(in_channels=1, out_channels=3)

In [85]:
# Smoke test: push one random 1x1x256x256 tensor through the network.
a = torch.randn((1, 1, 256, 256))
b = unet(a)

In [86]:
# Show the size of the network output.
b.size()

torch.Size([1, 3, 256, 256])