In [1]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from models_unet import Discriminator, Generator, initialize_weights
import torch.optim as optim
import torchvision
from torch.utils.tensorboard import SummaryWriter

  from .autonotebook import tqdm as notebook_tqdm


In [76]:
import torch
import torch.nn as nn

class Unet(nn.Module):
    """Three-level U-Net that maps an image to an image of the same spatial size.

    Each encoder stage halves the spatial resolution with a stride-2, 4x4
    convolution (+ BatchNorm + LeakyReLU(0.2)). Each decoder stage doubles it
    with a stride-2, 4x4 transposed convolution (+ BatchNorm + ReLU). Encoder
    activations are concatenated onto the matching decoder activations along
    the channel dimension (skip connections). The last layer applies ``tanh``,
    so outputs lie in [-1, 1].

    Args:
        in_channels: Number of channels of the input image.
        out_channels: Number of channels of the output image.
        features: Channel width of the first encoder stage; the deeper stages
            use ``2 * features`` and ``4 * features``.

    Shapes:
        input:  [N, in_channels, H, W]
        output: [N, out_channels, H, W]
        H and W must both be divisible by 8 (= 2 ** number of encoder stages);
        otherwise the skip connections cannot line up and ``forward`` raises
        ``ValueError``.
    """

    # Three stride-2 stages => spatial dims must be multiples of 2 ** 3.
    _SIZE_MULTIPLE = 8

    def __init__(self, in_channels, out_channels, features):
        super().__init__()

        # Encoder (spatial size halves at every stage)
        self.enc1 = self._conv_block(in_channels, features, 4, 2, 1)  # [in_channels, H, W] -> [features, H/2, W/2]
        self.enc2 = self._conv_block(features, features * 2, 4, 2, 1)  # [features, H/2, W/2] -> [features*2, H/4, W/4]
        self.enc3 = self._conv_block(features * 2, features * 4, 4, 2, 1)  # [features*2, H/4, W/4] -> [features*4, H/8, W/8]

        # Decoder (spatial size doubles at every stage). dec2/dec3 take twice
        # the channels their predecessor emits because a skip connection is
        # concatenated onto their input.
        self.dec1 = self._tconv_block(features * 4, features * 2, 4, 2, 1)  # [features*4, H/8, W/8] -> [features*2, H/4, W/4]
        self.dec2 = self._tconv_block(features * 4, features, 4, 2, 1)  # [features*4, H/4, W/4] -> [features, H/2, W/2]
        self.dec3 = nn.Sequential(
            nn.ConvTranspose2d(features * 2, out_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )  # [features*2, H/2, W/2] -> [out_channels, H, W]

    def _conv_block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Downsampling block: Conv2d -> BatchNorm2d -> LeakyReLU(0.2)."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2, inplace=True)
        )

    def _tconv_block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Upsampling block: ConvTranspose2d -> BatchNorm2d -> ReLU."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        """Run the U-Net on ``x`` of shape [N, in_channels, H, W].

        Returns:
            Tensor of shape [N, out_channels, H, W] with values in [-1, 1].

        Raises:
            ValueError: if H or W is not a multiple of 8. Without this check,
                some sizes fail with an opaque ``torch.cat`` size-mismatch
                error, and others silently yield an output of a different
                spatial size than the input.
        """
        height, width = x.shape[-2:]
        if height % self._SIZE_MULTIPLE or width % self._SIZE_MULTIPLE:
            raise ValueError(
                f"Unet expects height and width divisible by {self._SIZE_MULTIPLE}, "
                f"got {height}x{width}"
            )

        # Encoding path
        enc1 = self.enc1(x)  # [N, features, H/2, W/2]
        enc2 = self.enc2(enc1)  # [N, features*2, H/4, W/4]
        enc3 = self.enc3(enc2)  # [N, features*4, H/8, W/8]

        # Decoding path with skip connections (concatenated on the channel dim)
        dec1 = self.dec1(enc3)  # [N, features*2, H/4, W/4]
        dec1 = torch.cat((dec1, enc2), dim=1)  # [N, features*4, H/4, W/4]
        dec2 = self.dec2(dec1)  # [N, features, H/2, W/2]
        dec2 = torch.cat((dec2, enc1), dim=1)  # [N, features*2, H/2, W/2]
        out = self.dec3(dec2)  # [N, out_channels, H, W]
        return out

In [74]:
# Smoke test: a U-Net generator must return an image with the same shape as its input.
gen = Unet(3, 3, 32)
a = torch.ones(1, 3, 64, 64)
with torch.no_grad():  # shape check only; no autograd graph needed
    out_shape = gen(a).shape
assert out_shape == a.shape, f"expected {tuple(a.shape)}, got {tuple(out_shape)}"
out_shape

torch.Size([1, 128, 8, 8])
torch.Size([1, 3, 64, 64])
