# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


# Identity-Enhanced Block
class IdentityEnhancedBlock(nn.Module):
    """
    Residual refinement block that keeps its input largely intact.

    A small branch (conv -> BatchNorm -> ReLU -> conv) is computed from the
    input and added back onto it at half strength. The output has the same
    shape and channel count as the input.
    """
    def __init__(self, ch: int):
        super().__init__()
        # Both convolutions are 3x3 with padding 1, so height/width are kept.
        self.conv1 = nn.Conv2d(in_channels=ch, out_channels=ch, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=ch, out_channels=ch, kernel_size=3, padding=1)
        self.norm = nn.BatchNorm2d(num_features=ch)

    def forward(self, x):
        branch = self.conv1(x)
        branch = self.norm(branch)
        branch = F.relu(branch)
        branch = self.conv2(branch)
        # Skip connection plus the branch scaled by 0.5.
        return x + 0.5 * branch



# Generator
class Generator(nn.Module):
    """
    Upsampling generator that maps a latent vector to an image.

    The latent is projected by a linear layer and reshaped to a
    (base*8, init_size, init_size) feature map. Three stages follow, each
    doing 2x nearest-neighbour upsampling, a 3x3 conv, BatchNorm and ReLU,
    and each followed by an IdentityEnhancedBlock. A final 3x3 conv with
    Tanh produces the image, so values lie in [-1, 1].

    The output side length is init_size * 8: 112 with the default
    init_size=14, or e.g. 224 with init_size=28.

    Args:
        z_dim: Length of the latent vector.
        img_ch: Number of channels in the generated image.
        base: Channel width multiplier (the first feature map has base*8
            channels, the last has base).
        init_size: Side length of the feature map produced by the linear
            projection. Defaults to 14, the previously hard-coded value.

    Raises:
        ValueError: If init_size is smaller than 1.
    """
    def __init__(self, z_dim=128, img_ch=3, base=64, init_size=14):
        super().__init__()
        if init_size < 1:
            raise ValueError(f"init_size must be >= 1, got {init_size}")

        # Remembered so forward() can reshape without hard-coded numbers.
        self.init_ch = base * 8
        self.init_size = init_size

        self.fc = nn.Linear(z_dim, self.init_ch * init_size * init_size)

        self.up1 = self._up_stage(base*8, base*4)
        self.id1 = IdentityEnhancedBlock(base*4)

        self.up2 = self._up_stage(base*4, base*2)
        self.id2 = IdentityEnhancedBlock(base*2)

        self.up3 = self._up_stage(base*2, base)
        self.id3 = IdentityEnhancedBlock(base)

        self.out = nn.Sequential(
            nn.Conv2d(base, img_ch, 3, padding=1),
            nn.Tanh()
        )

    @staticmethod
    def _up_stage(in_ch, out_ch):
        """Build one 2x upsample -> 3x3 conv -> BatchNorm -> ReLU stage."""
        return nn.Sequential(
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(True),
        )

    def forward(self, z):
        """Generate images from latents z of shape (batch, z_dim)."""
        x = self.fc(z)
        # Explicit channel count (rather than -1) so a wrongly shaped latent
        # fails here instead of being silently folded into the channel axis.
        x = x.view(z.size(0), self.init_ch, self.init_size, self.init_size)
        x = self.up1(x)
        x = self.id1(x)
        x = self.up2(x)
        x = self.id2(x)
        x = self.up3(x)
        x = self.id3(x)
        return self.out(x)


# Discriminator
class Discriminator(nn.Module):
    """
    Convolutional discriminator producing one unnormalised score per image.

    Four stride-2 4x4 convolutions (each followed by LeakyReLU(0.2)) halve
    the spatial size four times, i.e. reduce it by a factor of 16. The
    resulting feature map is flattened and fed to a single linear layer.

    Because the linear head has a fixed input size, the module only accepts
    images whose side length matches img_size. The default of 224 keeps the
    previously hard-coded 14x14 final feature map; pass e.g. img_size=112 to
    score 112x112 images such as those the Generator in this file emits by
    default.

    Args:
        img_ch: Number of channels in the input image.
        base: Channel width multiplier (feature maps have base, base*2,
            base*4 and base*8 channels).
        img_size: Side length of the (square) input images.

    Raises:
        ValueError: If img_size is smaller than 16, which would leave no
            spatial positions after the four downsampling steps.
    """
    def __init__(self, img_ch=3, base=64, img_size=224):
        super().__init__()
        if img_size < 16:
            raise ValueError(f"img_size must be >= 16, got {img_size}")

        def block(in_ch, out_ch):
            # kernel 4 / stride 2 / padding 1 maps a side of n to n // 2.
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 4, stride=2, padding=1),
                nn.LeakyReLU(0.2, inplace=True)
            )

        self.net = nn.Sequential(
            block(img_ch, base),
            block(base, base*2),
            block(base*2, base*4),
            block(base*4, base*8),
        )

        self.img_size = img_size
        # Four halvings in a row are equivalent to floor division by 16.
        feat_size = img_size // 16
        self.head = nn.Linear(base*8*feat_size*feat_size, 1)

    def forward(self, x):
        """Return a (batch, 1) tensor of scores for images x."""
        f = self.net(x)
        # flatten(1) also copes with conv output that is not contiguous
        # (e.g. channels_last input), where .view() can raise.
        f = f.flatten(1)
        return self.head(f)


class IdentityEncoder(nn.Module):
    """
    Small convolutional encoder intended for use as an identity constraint.

    Three stride-2 3x3 convolutions with ReLU are followed by global average
    pooling and a linear projection to an out_dim-dimensional embedding.

    Note: this class does not itself disable gradients on its parameters; a
    caller that wants the encoder frozen has to arrange that separately.
    """
    def __init__(self, img_ch=3, base=32, out_dim=128):
        super().__init__()
        widths = [img_ch, base, base*2, base*4]
        stages = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stages.append(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1))
            stages.append(nn.ReLU())
        # Collapse each channel to a single value regardless of input size.
        stages.append(nn.AdaptiveAvgPool2d(1))
        self.enc = nn.Sequential(*stages)
        self.fc = nn.Linear(widths[-1], out_dim)

    def forward(self, x):
        pooled = self.enc(x)
        embedding_in = pooled.flatten(start_dim=1)
        return self.fc(embedding_in)
