In [1]:
import sys
sys.path.append("..")

from general import *
from general.model_utils import *

In [2]:
class Generator(nn.Module):
    """Generator network: five ``CT2D_BN_A`` blocks chained in an ``nn.Sequential``.

    ``CT2D_BN_A`` is a helper pulled in by the star imports at the top of the
    notebook; presumably a transposed-conv + batch-norm + activation block --
    TODO confirm against its definition.

    Args:
        embedding_size: ``in_channels`` handed to the first block.
        channels: ``out_channels`` handed to the last block.
        filter_count: indexable collection of channel widths. Only the last
            four entries are read, back to front (``[-1]`` ... ``[-4]``); with
            the default that is 64 -> 128 -> 256 -> 256, and entry 0 (512) is
            never used.

    With the defaults, the notebook's check cell records an input of shape
    (1, 256, 2, 2) producing an output of shape (1, 3, 128, 128).
    """

    # NOTE: the default ``filter_count`` array is built once, when the class is
    # defined, and is stored by reference -- default-constructed instances share it.
    def __init__(self, embedding_size = 256, channels = 3, filter_count = np.array([8,4,4,2,1])*64):
        super().__init__()
        self.embedding_size = embedding_size
        self.filter_count = filter_count
        self.channels = channels

        widths = self.filter_count
        # One row per block: (in_channels, out_channels, kernel_size, activation_type).
        # Every block uses stride 2; only the last block switches to "tanh".
        block_specs = (
            (self.embedding_size, widths[-1], 5, "relu"),
            (widths[-1], widths[-2], 3, "relu"),
            (widths[-2], widths[-3], 3, "relu"),
            (widths[-3], widths[-4], 3, "relu"),
            (widths[-4], self.channels, 4, "tanh"),
        )
        # Blocks are instantiated in table order, so they land at indices 0..4
        # of the Sequential container.
        blocks = [
            CT2D_BN_A(
                in_channels = n_in,
                out_channels = n_out,
                kernel_size = kernel,
                stride = 2,
                activation_type = activation,
            )
            for n_in, n_out, kernel, activation in block_specs
        ]
        self.generator = nn.Sequential(*blocks)

    def forward(self, embeddings):
        """Run ``embeddings`` through the block stack and return the result."""
        generated = self.generator(embeddings)
        return generated

In [3]:
# Smoke test: build a Generator with default arguments, feed it a random
# tensor of shape (1, 256, 2, 2) (256 matches the default embedding_size),
# and display the shape of what comes out.
Generator()(torch.rand(1,256,2,2)).shape

torch.Size([1, 3, 128, 128])

In [97]:
class Discriminator(nn.Module):
    """Discriminator network: five ``C2D_BN_A`` blocks, then ``nn.Flatten`` and ``nn.Sigmoid``.

    ``C2D_BN_A`` is a helper pulled in by the star imports at the top of the
    notebook; presumably a conv + batch-norm + activation block -- TODO
    confirm against its definition.

    Args:
        channels: ``in_channels`` handed to the first block.
        filter_count: indexable collection of channel widths. Only entries
            0..3 are read, front to back; with the default that is
            512 -> 256 -> 256 -> 128, and the last entry (64) is never used.
        classes: ``out_channels`` handed to the last block.

    With the defaults, the notebook's check cell records an input of shape
    (12, 3, 128, 128) producing an output of shape (12, 1).
    """

    # NOTE: the default ``filter_count`` array is built once, when the class is
    # defined, and is stored by reference -- default-constructed instances share it.
    def __init__(self, channels = 3, filter_count = np.array([8,4,4,2,1])*64, classes = 1):
        super().__init__()
        self.channels = channels
        self.filter_count = filter_count
        self.classes = classes

        widths = self.filter_count
        # One row per block: (in_channels, out_channels, kernel_size, stride).
        block_specs = (
            (self.channels, widths[0], 5, 3),
            (widths[0], widths[1], 5, 2),
            (widths[1], widths[2], 5, 2),
            (widths[2], widths[3], 3, 2),
            (widths[3], self.classes, 3, 2),
        )
        # Blocks are instantiated in table order (indices 0..4); the
        # parameter-free Flatten and Sigmoid follow at indices 5 and 6.
        blocks = [
            C2D_BN_A(
                in_channels = n_in,
                out_channels = n_out,
                kernel_size = kernel,
                stride = step,
            )
            for n_in, n_out, kernel, step in block_specs
        ]
        self.discriminator = nn.Sequential(*blocks, nn.Flatten(), nn.Sigmoid())

    def forward(self, images):
        """Run ``images`` through the block stack, flatten, and apply the sigmoid."""
        scores = self.discriminator(images)
        return scores

In [98]:
# Smoke test: build a Discriminator with default arguments, feed it a random
# batch of shape (12, 3, 128, 128) (3 matches the default channels), and
# display the shape of what comes out.
Discriminator()(torch.rand(12,3,128,128)).shape

torch.Size([12, 1])