In [None]:
%pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118
%pip install Pillow==9.4.0

In [None]:
import os, glob
from dataclasses import dataclass, asdict
from typing import Any, List, Optional, Tuple, Union

from PIL import Image

import torch
import torch.nn as nn
from torch import Tensor
import torch.nn.functional as fn
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms

In [None]:
class DownConv(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel: int = 4,
        stride: int = 2,
        padding: int = 1,
        activation: str = "leaky_relu",
        do_batch_norm: bool = True,
        num_groups: int = 32,
        negative_slope = 0.2,
    ):
        super().__init__()

        self.norm = nn.GroupNorm(num_groups, in_channels) if do_batch_norm else None
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel, stride=stride, padding=padding
        )

        match activation.lower():
            case "leaky_relu":
                self.act_fn = nn.LeakyReLU(negative_slope=negative_slope)
            case "identity":
                self.act_fn = nn.Identity()
            case _:
                raise NotImplementedError(f"`activation` must be `leaky_relu` or `identity`")

    def forward(self, x: Tensor) -> Tensor:
        if self.norm:
            x = self.norm(x)
        x = self.conv(x)
        x = self.act_fn(x)
        return x

class Encoder(nn.Module):
    """Downsampling half of the model: a sequence of ``DownConv`` blocks.

    One block is built per entry of ``down_out_channels``; block ``i`` maps the
    previous block's channels (``in_channels`` for the first block) to
    ``down_out_channels[i]``. Every other option is either a single value shared
    by all blocks or a sequence holding one value per block.

    Raises:
        ValueError: If a per-block sequence's length differs from
            ``len(down_out_channels)``; options are checked in signature order.
    """

    def __init__(
        self,
        *,
        in_channels: int,
        down_out_channels: Tuple[int],
        kernels: Union[int, Tuple[int]],
        strides: Union[int, Tuple[int]],
        paddings: Union[int, Tuple[int]],
        do_batch_norms: Union[bool, Tuple[bool]],
        activations: Union[str, Tuple[str]],
    ):
        super().__init__()

        n_blocks = len(down_out_channels)

        def expand(name, value, scalar_type):
            # A scalar is repeated for every block; a sequence must already hold
            # exactly one entry per block.
            if isinstance(value, scalar_type):
                return (value,) * n_blocks
            if len(value) != n_blocks:
                raise ValueError(f"`{name}` must have the same length as `down_out_channels`")
            return value

        kernels = expand("kernels", kernels, int)
        strides = expand("strides", strides, int)
        paddings = expand("paddings", paddings, int)
        do_batch_norms = expand("do_batch_norms", do_batch_norms, bool)
        activations = expand("activations", activations, str)

        blocks = []
        prev_channels = in_channels
        for idx in range(n_blocks):
            next_channels = down_out_channels[idx]
            blocks.append(
                DownConv(
                    prev_channels,
                    next_channels,
                    kernel=kernels[idx],
                    stride=strides[idx],
                    padding=paddings[idx],
                    activation=activations[idx],
                    do_batch_norm=do_batch_norms[idx],
                )
            )
            prev_channels = next_channels
        # Sequential names its children "0", "1", ... exactly as appending does.
        self.down_blocks = nn.Sequential(*blocks)

    def forward(self, x: Tensor) -> Tensor:
        """Run ``x`` through every block in order."""
        return self.down_blocks(x)

In [None]:
class UpConv(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel: int = 4,
        stride: int = 2,
        padding: int = 1,
        activation: str = "relu",
        do_batch_norm: bool = True,
        num_groups: int = 32,
    ):
        super().__init__()

        self.norm = nn.GroupNorm(num_groups, in_channels) if do_batch_norm else None
        self.conv = nn.ConvTranspose2d(
            in_channels, out_channels, kernel, stride=stride, padding=padding
        )

        match activation.lower():
            case "relu":
                self.act_fn = nn.ReLU()
            case "tanh":
                self.act_fn = nn.Tanh()
            case _:
                raise NotImplementedError(f"`activation` must be `relu` or `tanh`")

    def forward(self, x: Tensor) -> Tensor:
        if self.norm:
            x = self.norm(x)
        x = self.conv(x)
        x = self.act_fn(x)
        return x

class Decoder(nn.Module):
    """Upsampling half of the model: a sequence of ``UpConv`` blocks.

    One block is built per entry of ``up_out_channels``; block ``i`` maps the
    previous block's channels (``in_channels`` for the first block) to
    ``up_out_channels[i]``. Every other option is either a single value shared
    by all blocks or a sequence holding exactly one value per block.

    Args:
        in_channels: Channels of the decoder input.
        up_out_channels: Output channels of each block, in order.
        kernels: Transposed-conv kernel size(s).
        strides: Transposed-conv stride(s).
        paddings: Transposed-conv padding(s).
        do_batch_norms: Whether each block applies GroupNorm to its input.
        activations: Activation name(s) passed to ``UpConv``.

    Raises:
        ValueError: If a per-block sequence does not have one entry per block.
    """

    def __init__(
        self,
        *,
        in_channels: int,
        up_out_channels: Tuple[int, ...],
        kernels: Union[int, Tuple[int, ...]],
        strides: Union[int, Tuple[int, ...]],
        paddings: Union[int, Tuple[int, ...]],
        do_batch_norms: Union[bool, Tuple[bool, ...]],
        activations: Union[str, Tuple[str, ...]],
    ):
        super().__init__()

        # check inputs
        num_blocks = len(up_out_channels)
        if not isinstance(kernels, int) and len(kernels) != num_blocks:
            raise ValueError("`kernels` must have the same length as `up_out_channels`")
        if not isinstance(strides, int) and len(strides) != num_blocks:
            raise ValueError("`strides` must have the same length as `up_out_channels`")
        if not isinstance(paddings, int) and len(paddings) != num_blocks:
            raise ValueError("`paddings` must have the same length as `up_out_channels`")
        # A single flag is a bool; plain ints were accepted before and still are.
        if not isinstance(do_batch_norms, (bool, int)) and len(do_batch_norms) != num_blocks:
            raise ValueError("`do_batch_norms` must have the same length as `up_out_channels`")
        # A single activation is a *str*. Testing for `int` here (as before) made a
        # name like "relu" look like a 4-element sequence of characters.
        if not isinstance(activations, str) and len(activations) != num_blocks:
            raise ValueError("`activations` must have the same length as `up_out_channels`")

        # broadcast scalar options to one value per block
        if isinstance(kernels, int):
            kernels = (kernels,) * num_blocks
        if isinstance(strides, int):
            strides = (strides,) * num_blocks
        if isinstance(paddings, int):
            paddings = (paddings,) * num_blocks
        if isinstance(do_batch_norms, (bool, int)):
            do_batch_norms = (do_batch_norms,) * num_blocks
        if isinstance(activations, str):
            activations = (activations,) * num_blocks

        self.up_blocks = nn.Sequential()
        for i in range(num_blocks):
            out_channels = up_out_channels[i]
            self.up_blocks.append(
                UpConv(
                    in_channels,
                    out_channels,
                    kernel=kernels[i],
                    stride=strides[i],
                    padding=paddings[i],
                    activation=activations[i],
                    do_batch_norm=do_batch_norms[i],
                )
            )
            in_channels = out_channels

    def forward(self, x: Tensor) -> Tensor:
        """Run ``x`` through every block in order."""
        x = self.up_blocks(x)
        return x

In [None]:
@dataclass
class EncoderConfig:
    """Default ``Encoder`` settings; field names match ``Encoder``'s keyword arguments.

    For a 3x64x64 input these settings give a 512x1x1 output: four stride-2
    4x4 convolutions halve 64 -> 4, then a 4x4 convolution with stride 1 and no
    padding reduces 4 -> 1.
    """

    in_channels: int = 3
    down_out_channels: Tuple[int, ...] = (64, 128, 256, 512, 512)
    kernels: Union[int, Tuple[int, ...]] = 4
    strides: Union[int, Tuple[int, ...]] = (2, 2, 2, 2, 1)
    paddings: Union[int, Tuple[int, ...]] = (1, 1, 1, 1, 0)
    # No normalization on the first block (its input is the raw 3-channel image,
    # which a 32-group GroupNorm could not split) or on the last block.
    do_batch_norms: Union[bool, Tuple[bool, ...]] = (False, True, True, True, False)
    # The last block is linear ("identity"), so its output is not passed through LeakyReLU.
    activations: Union[str, Tuple[str, ...]] = (
        "leaky_relu",
        "leaky_relu",
        "leaky_relu",
        "leaky_relu",
        "identity",
    )

@dataclass
class DecoderConfig:
    """Default ``Decoder`` settings; field names match ``Decoder``'s keyword arguments.

    For a 512x1x1 input these settings give a 3x64x64 output: a 4x4 transposed
    convolution with stride 1 and no padding expands 1 -> 4, then four stride-2
    4x4 transposed convolutions double 4 -> 64. The final tanh keeps outputs in
    [-1, 1].
    """

    in_channels: int = 512
    up_out_channels: Tuple[int, ...] = (512, 256, 128, 64, 3)
    kernels: Union[int, Tuple[int, ...]] = 4
    strides: Union[int, Tuple[int, ...]] = (1, 2, 2, 2, 2)
    paddings: Union[int, Tuple[int, ...]] = (0, 1, 1, 1, 1)
    # GroupNorm on every block input except the last one.
    do_batch_norms: Union[bool, Tuple[bool, ...]] = (True, True, True, True, False)
    activations: Union[str, Tuple[str, ...]] = (
        "relu",
        "relu",
        "relu",
        "relu",
        "tanh",
    )

class CelebAModel(nn.Module):
    """Image-to-image network: an ``Encoder`` followed by a ``Decoder``.

    Args:
        encoder_config: Keyword arguments for ``Encoder`` (expanded with ``asdict``).
        decoder_config: Keyword arguments for ``Decoder`` (expanded with ``asdict``).
        do_init: If True, re-initialize every ``Conv2d``/``ConvTranspose2d``:
            weights from ``N(init_mean, init_std)`` and biases to ``init_const``.
            Other modules (e.g. GroupNorm) keep PyTorch's default initialization.
        init_mean: Mean of the weight initialization.
        init_std: Standard deviation of the weight initialization.
        init_const: Constant used for conv biases.

    Raises:
        ValueError: If the decoder's ``in_channels`` differs from the encoder's
            last ``down_out_channels`` entry.
    """

    def __init__(
        self,
        encoder_config: EncoderConfig,
        decoder_config: DecoderConfig,
        do_init: bool = True,
        init_mean: float = 0.0,
        init_std: float = 0.02,
        init_const: float = 0.0,
    ):
        super().__init__()

        if decoder_config.in_channels != encoder_config.down_out_channels[-1]:
            raise ValueError(
                "`in_channels` for decoder must be the same as the last element of `down_out_channels` for encoder"
            )

        self.encoder = Encoder(**asdict(encoder_config))
        self.decoder = Decoder(**asdict(decoder_config))

        if do_init:
            self._initialize_params(mean=init_mean, std=init_std, const=init_const)

    def forward(self, x: Tensor) -> Tensor:
        """Encode ``x`` and decode the result."""
        x = self.encoder(x)
        x = self.decoder(x)
        return x

    def _initialize_params(self, mean: float, std: float, const: float):
        """Re-initialize all (transposed) convolution layers in place."""

        def init_params(module: nn.Module):
            # A tuple of classes is the standard isinstance form; passing a
            # typing.Union to isinstance only works on Python 3.10+.
            if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
                torch.nn.init.normal_(module.weight, mean=mean, std=std)
                # Convolutions built with bias=False have no bias tensor to fill.
                if module.bias is not None:
                    torch.nn.init.constant_(module.bias, val=const)

        self.apply(init_params)


In [None]:
class ImageDataset(Dataset):
    """Images from a flat directory, each returned as a square RGB tensor.

    Files matching ``*.{ext}`` directly inside ``data_dir`` are listed once and
    sorted by path. Each item is converted to RGB, resized with
    ``transforms.Resize(resolution)`` (bilinear), cropped to
    ``resolution x resolution`` and converted with ``transforms.ToTensor``.

    Args:
        data_dir: Directory containing the images.
        resolution: Output height and width.
        center_crop: Center crop if True, otherwise random crop.
        ext: File extension to match, without the leading dot.
    """

    def __init__(
        self,
        data_dir: Union[str, os.PathLike],
        resolution: int = 64,
        center_crop: bool = True,
        ext: str = "jpg",
    ):
        # Escape the directory part so characters such as `[`, `*` or `?` in its
        # name are matched literally; only the file-name part is a wildcard.
        pattern = os.path.join(glob.escape(os.fspath(data_dir)), f"*.{ext}")
        self.images = sorted(glob.glob(pattern))
        self.pre_proc = transforms.Compose(
            [
                transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(resolution) if center_crop else transforms.RandomCrop(resolution),
                transforms.ToTensor(),
            ]
        )

    def __len__(self) -> int:
        return len(self.images)

    def __getitem__(self, idx: int) -> Tensor:
        """Load the ``idx``-th image, convert it to RGB and preprocess it."""
        # The context manager closes the file handle deterministically instead of
        # leaving it to PIL's lazy loading / garbage collection.
        with Image.open(self.images[idx]) as img:
            return self.pre_proc(img.convert("RGB"))

def collate_fn(image_batch: List[Tensor]) -> Tensor:
    """Stack same-shaped tensors along a new leading batch dim as contiguous float32."""
    stacked = torch.stack(image_batch, dim=0)
    stacked = stacked.contiguous(memory_format=torch.contiguous_format)
    return stacked.to(dtype=torch.float32)


In [None]:
def get_freq_means_and_stds(x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """Per-frequency statistics of a batch's 2-D Fourier spectrum.

    ``torch.fft.fft2`` is taken over the last two dimensions of ``x``. The mean
    and (unbiased) standard deviation of the real and imaginary parts are then
    computed across dim 0, so every result has shape ``x.shape[1:]``.

    Args:
        x: A batch with at least one sample along dim 0.

    Returns:
        ``(real_mean, real_std, imag_mean, imag_std)``.

    Raises:
        ValueError: If the batch dimension is empty.
    """
    if x.shape[0] == 0:
        raise ValueError("`x` must contain at least one sample")
    freq = torch.fft.fft2(x)
    real_mean = freq.real.mean(dim=0)
    imag_mean = freq.imag.mean(dim=0)
    if x.shape[0] > 1:
        real_std = freq.real.std(dim=0)
        imag_std = freq.imag.std(dim=0)
    else:
        # The unbiased std of a single sample divides by n - 1 = 0 and is NaN,
        # which is not a valid std for torch.normal; one sample has no spread.
        real_std = torch.zeros_like(real_mean)
        imag_std = torch.zeros_like(imag_mean)
    return real_mean, real_std, imag_mean, imag_std

def get_noise(
    real_mean: Tensor,
    real_std: Tensor,
    imag_mean: Tensor,
    imag_std: Tensor,
) -> Tensor:
    """Draw one spatial-domain noise sample from per-frequency Gaussian statistics.

    The real and imaginary parts of every frequency are sampled independently
    with ``torch.normal`` (real part first), combined into a complex spectrum,
    and mapped back with an inverse 2-D FFT over the last two dimensions. Only
    the real part of that inverse transform is returned.
    """
    # Sampling order matters for reproducibility: real part, then imaginary part.
    real_part, imag_part = [
        torch.normal(mu, sigma)
        for mu, sigma in ((real_mean, real_std), (imag_mean, imag_std))
    ]
    spectrum = real_part + 1j * imag_part
    spatial = torch.fft.ifft2(spectrum)
    return spatial.real

In [None]:
batch_size = 256
lr = 1e-4
adam_betas = (0.5, 0.999)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed the global torch RNG so weight initialization, DataLoader shuffling and
# noise sampling are reproducible; the same seed drives the train/test split.
seed = 123
torch.manual_seed(seed)

enc_config = EncoderConfig()
dec_config = DecoderConfig()
model = CelebAModel(enc_config, dec_config, do_init=True)
# Frozen twin of `model`: the training loop copies `model`'s weights into it every
# step and uses it for the idempotence term, so it never receives gradients.
model_copy = CelebAModel(enc_config, dec_config, do_init=False).requires_grad_(False)
model.to(device)
model_copy.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=adam_betas)

dataset = ImageDataset("./img_align_celeba")
generator = torch.Generator().manual_seed(seed)
train_dataset, test_dataset = random_split(dataset, [0.9, 0.1], generator=generator)
train_dl = DataLoader(
    train_dataset, batch_size=batch_size, collate_fn=collate_fn, shuffle=True
)
test_dl = DataLoader(
    test_dataset, batch_size=batch_size, collate_fn=collate_fn, shuffle=False
)

In [None]:
rec_weight = 20
idem_weight = 20
tight_weight = 2.5
# Express the idempotence and tightness weights relative to the reconstruction
# weight, so the reconstruction term below has an implicit weight of 1.
idem_weight /= rec_weight
tight_weight /= rec_weight
# Scale of the smooth (tanh) clamp on the tightness loss, relative to the
# per-sample reconstruction loss.
loss_tight_clamp_ratio = 1.5

last_epoch = 0
train_loss_hist = []

checkpoint_dir = "ign-celeba"
os.makedirs(checkpoint_dir, exist_ok=True)
# Set to a file written below (e.g. "epoch_5.pt") to resume training from it.
checkpoint_name: Optional[str] = None
if isinstance(checkpoint_name, str):
    path = os.path.join(checkpoint_dir, checkpoint_name)
    # map_location lets a checkpoint saved on GPU be resumed on a CPU-only machine
    # (and vice versa): tensors are loaded straight onto the training device.
    checkpoint = torch.load(path, map_location=device)
    last_epoch = checkpoint["epoch"]
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optim_state_dict"])
    train_loss_hist = checkpoint["train_loss_hist"]

# Number of epochs to run in this session; when resuming, epoch numbering
# continues from `last_epoch`.
num_epochs = 200
save_interval = 5

for e in range(num_epochs):
    model.train()
    train_loss = 0.0
    for x in train_dl:
        bsz = x.shape[0]
        x = x.to(device)
        # normalize from the [0, 1] ToTensor range to [-1, 1]
        x = 2.0 * x - 1.0

        # get noise from input frequency statistics (one independent draw per sample)
        freq_means_and_stds = get_freq_means_and_stds(x)
        z = torch.stack([get_noise(*freq_means_and_stds) for _ in range(bsz)])
        z = z.to(device, memory_format=torch.contiguous_format)

        # compute model outputs
        # `model_copy` is a frozen snapshot of the current weights, so gradients of
        # `f_fz` reach `model` only through its input `fz`.
        model_copy.load_state_dict(model.state_dict())
        fx = model(x)
        fz = model(z)
        f_z = fz.detach()
        ff_z = model(f_z)
        f_fz = model_copy(fz)

        # compute losses (kept per sample where the tightness clamp needs them)
        loss_rec = fn.l1_loss(fx, x, reduction="none").view(bsz, -1).mean(dim=-1)
        loss_idem = fn.l1_loss(f_fz, fz, reduction="mean")
        loss_tight = -fn.l1_loss(ff_z, f_z, reduction="none").view(bsz, -1).mean(dim=-1)
        # Smoothly bound |loss_tight| by `loss_tight_clamp_ratio * loss_rec` per sample.
        loss_tight_clamp = loss_tight_clamp_ratio * loss_rec
        loss_tight = torch.tanh(loss_tight / loss_tight_clamp) * loss_tight_clamp
        loss_rec = loss_rec.mean()
        loss_tight = loss_tight.mean()

        loss = loss_rec + idem_weight * loss_idem + tight_weight * loss_tight

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item() * bsz

    train_loss /= len(train_dl.dataset)
    train_loss_hist.append(train_loss)

    epoch = e + last_epoch + 1
    print(f"Epoch {epoch} loss: {train_loss:.4f}")
    # save checkpoint every `save_interval` epochs and after the last epoch
    if epoch % save_interval == 0 or e == num_epochs - 1:
        path = os.path.join(checkpoint_dir, f"epoch_{epoch}.pt")
        torch.save(
            {
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optim_state_dict": optimizer.state_dict(),
                "train_loss_hist": train_loss_hist,
            },
            path,
        )


In [None]:
@torch.no_grad()
def generate(
    model: CelebAModel,
    num_steps: int,
    sample_x: Tensor,
    device: Any = torch.device("cpu"),
) -> List[Image.Image]:
    """Sample one noise image and repeatedly apply ``model`` to it.

    The noise is drawn from the per-frequency statistics of ``sample_x`` via
    ``get_freq_means_and_stds`` and ``get_noise``. ``sample_x`` should therefore
    be a batch in the same value range the model was trained on (the training
    loop computes these statistics on ``2 * x - 1``).

    Side effects: puts ``model`` in eval mode and moves it to ``device`` in place.

    Args:
        model: The trained network.
        num_steps: How many times ``model`` is applied in sequence.
        sample_x: Batch whose frequency statistics define the noise distribution.
        device: Device on which ``model`` is run.

    Returns:
        ``num_steps`` PIL images; element ``k`` is the output after ``k + 1``
        applications of ``model``.
    """
    model.eval()
    model.to(device)

    # generate image
    freq_means_and_stds = get_freq_means_and_stds(sample_x)
    z = get_noise(*freq_means_and_stds).unsqueeze(0).to(device)
    outputs = []
    for _ in range(num_steps):
        z = model(z)
        outputs.append(z.squeeze(0))

    images = []
    for out in outputs:
        # Denormalize from [-1, 1] to [0, 1]. Clamp so values outside that range
        # (possible if the final activation is not tanh) cannot wrap around in
        # the uint8 cast.
        pixels = (out / 2 + 0.5).clamp(0.0, 1.0)
        array = pixels.cpu().permute(1, 2, 0).float().numpy()
        images.append(Image.fromarray((255 * array).round().astype("uint8")))
    return images

In [None]:
num_images = 50
num_steps = 1
for i, sample_x in enumerate(test_dl):
    if i == num_images:
        break
    # Apply the training-time normalization (`x = 2.0 * x - 1.0`): the noise
    # statistics must come from images in the same [-1, 1] range the model saw.
    sample_x = 2.0 * sample_x - 1.0
    images = generate(model, num_steps, sample_x, device)
    for j, img in enumerate(images):
        img.save(f"gen{i}_step{j+1}.png")