# In [1]:
import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    """Residual unit: ``x + F(x)`` where F is conv3x3-norm-ReLU-conv3x3-norm.

    Both convolutions keep the channel count and spatial size unchanged
    (3x3 kernel, stride 1, padding 1), so the skip connection can be a plain
    element-wise addition. The convolutions omit a bias because the following
    affine InstanceNorm supplies its own learnable shift.

    NOTE(review): these InstanceNorm2d layers use ``track_running_stats=True``,
    so after ``.eval()`` they normalize with accumulated running statistics
    instead of per-instance statistics. The other InstanceNorm2d layers in this
    file use the defaults (no running stats, no affine). Confirm this
    difference is intended; changing it would alter the state_dict keys.

    Args:
        channels: Number of input (and output) feature channels.
    """

    def __init__(self, channels):
        super().__init__()

        def conv_then_norm():
            # One 3x3 same-size convolution followed by its normalization layer.
            return [
                nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1, bias=False),
                nn.InstanceNorm2d(channels, affine=True, track_running_stats=True),
            ]

        # Layer order (and therefore state_dict indices 0..4) matches:
        # conv, norm, relu, conv, norm.
        self.block = nn.Sequential(
            *conv_then_norm(),
            nn.ReLU(inplace=True),
            *conv_then_norm(),
        )

    def forward(self, x):
        residual = self.block(x)
        return x + residual

class Generator(nn.Module):
    """ResNet-style CycleGAN generator (encoder / residual trunk / decoder).

    Layout: 7x7 conv -> two stride-2 3x3 downsampling convs -> ``num_res_blocks``
    ResidualBlocks at 4x the base width -> two stride-2 transposed-conv
    upsampling layers -> 7x7 conv -> Tanh, so outputs lie in [-1, 1].
    For inputs whose height and width are divisible by 4, the output has the
    same spatial size as the input.

    Args:
        num_filters: Base channel width; the trunk uses ``num_filters * 4``.
        num_res_blocks: Number of residual blocks in the trunk. The default
            of 9 follows the original CycleGAN paper.
        in_channels: Channels of the input image (3 for RGB).
        out_channels: Channels of the generated image (3 for RGB).
    """

    def __init__(self, num_filters=64, num_res_blocks=9, in_channels=3, out_channels=3):
        super().__init__()

        self.model = nn.Sequential(
            # Initial convolution block (7x7, spatial size preserved)
            nn.Conv2d(in_channels, num_filters, kernel_size=7, padding=3),
            nn.InstanceNorm2d(num_filters),
            nn.ReLU(inplace=True),

            # 2 downsampling layers (H -> H/2 -> H/4)
            nn.Conv2d(num_filters, num_filters * 2, kernel_size=3, stride=2, padding=1),
            nn.InstanceNorm2d(num_filters * 2),
            nn.ReLU(inplace=True),
            nn.Conv2d(num_filters * 2, num_filters * 4, kernel_size=3, stride=2, padding=1),
            nn.InstanceNorm2d(num_filters * 4),
            nn.ReLU(inplace=True),

            # num_res_blocks residual blocks at the bottleneck resolution
            *[ResidualBlock(num_filters * 4) for _ in range(num_res_blocks)],

            # 2 upsampling layers (H/4 -> H/2 -> H); output_padding=1 makes
            # each transposed conv exactly double the spatial size.
            nn.ConvTranspose2d(num_filters * 4, num_filters * 2, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(num_filters * 2),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(num_filters * 2, num_filters, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(num_filters),
            nn.ReLU(inplace=True),

            # Output layer; Tanh bounds values to [-1, 1]
            nn.Conv2d(num_filters, out_channels, kernel_size=7, padding=3),
            nn.Tanh()
        )

    def forward(self, x):
        return self.model(x)

# discriminator
class Discriminator(nn.Module):
    """Convolutional patch discriminator producing a map of raw scores.

    Four 4x4 stride-2 conv layers (C64-C128-C256-C512 at the default width,
    InstanceNorm on all but the first, LeakyReLU slope 0.2), followed by a 4x4
    stride-1 conv to a single channel with no final activation. For inputs
    whose height and width are divisible by 16, the output map is
    ``(H/16 - 1) x (W/16 - 1)``.

    NOTE(review): the CycleGAN reference implementation uses stride 1 for the
    C512 layer, which gives the 70x70 PatchGAN receptive field. Here it is
    stride 2, which enlarges the receptive field and halves the output map.
    Confirm this is intended.

    Args:
        num_filters: Width of the first layer; later layers use 2x, 4x, 8x.
        in_channels: Channels of the input image (3 for RGB).
    """

    def __init__(self, num_filters=64, in_channels=3):
        super().__init__()

        self.model = nn.Sequential(
            # C64 layer (no InstanceNorm)
            nn.Conv2d(in_channels, num_filters, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),

            # C128 layer
            nn.Conv2d(num_filters, num_filters * 2, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(num_filters * 2),
            nn.LeakyReLU(0.2, inplace=True),

            # C256 layer
            nn.Conv2d(num_filters * 2, num_filters * 4, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(num_filters * 4),
            nn.LeakyReLU(0.2, inplace=True),

            # C512 layer
            nn.Conv2d(num_filters * 4, num_filters * 8, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(num_filters * 8),
            nn.LeakyReLU(0.2, inplace=True),

            # 1-channel output map of unbounded scores (no sigmoid)
            nn.Conv2d(num_filters * 8, 1, kernel_size=4, stride=1, padding=1)
        )

    def forward(self, x):
        return self.model(x)

# loss functions
class CycleGANLosses:
    """Loss terms for CycleGAN training.

    Uses a least-squares (MSE) adversarial loss and L1 cycle-consistency and
    identity losses.

    Args:
        device: Stored as ``self.device`` for backward compatibility. Target
            tensors are built with ``torch.ones_like`` / ``torch.zeros_like``,
            which already match the prediction's device and dtype, so this
            value is not used to place them.
    """

    def __init__(self, device):
        self.device = device
        self.adversarial_loss = nn.MSELoss()
        self.cycle_loss = nn.L1Loss()
        self.identity_loss = nn.L1Loss()

    def compute_g_loss(self, fake, real, cycle, identity, lambda_cyc, lambda_id):
        """Return the total generator loss as a scalar tensor.

        Args:
            fake: Scores the discriminator assigned to generated images;
                regressed toward 1 (the "real" label).
            real: Reference images that ``cycle`` and ``identity`` are compared to.
            cycle: Cycle-reconstructed images, compared to ``real`` with L1.
            identity: Identity-mapped images, compared to ``real`` with L1.
            lambda_cyc: Weight of the cycle-consistency term.
            lambda_id: Weight of the identity term.

        Returns:
            ``adv + lambda_cyc * cycle_l1 + lambda_id * identity_l1``.
        """
        # ones_like already matches fake's device/dtype; an extra .to(device)
        # could only move it away from fake and cause a device mismatch.
        adv_loss = self.adversarial_loss(fake, torch.ones_like(fake))
        cyc_loss = self.cycle_loss(cycle, real)
        id_loss = self.identity_loss(identity, real)
        return adv_loss + lambda_cyc * cyc_loss + lambda_id * id_loss

    def compute_d_loss(self, real, fake):
        """Return the discriminator loss as a scalar tensor.

        Args:
            real: Discriminator scores for real images; regressed toward 1.
            fake: Discriminator scores for generated images; regressed toward 0.

        Returns:
            The mean of the two MSE terms. Halving the objective follows the
            CycleGAN paper's practice of slowing D relative to G.
        """
        real_loss = self.adversarial_loss(real, torch.ones_like(real))
        fake_loss = self.adversarial_loss(fake, torch.zeros_like(fake))
        return (real_loss + fake_loss) / 2


def initialize_models(device):
    """Create the four CycleGAN networks with default hyperparameters on *device*.

    Returns:
        Tuple ``(g_ab, g_ba, d_a, d_b)``: two Generator instances followed by
        two Discriminator instances, each moved to ``device``. They are
        constructed in exactly that order, so weight initialization draws from
        the global RNG in the same sequence.
    """
    g_ab, g_ba = (Generator().to(device) for _ in range(2))
    d_a, d_b = (Discriminator().to(device) for _ in range(2))
    return g_ab, g_ba, d_a, d_b