In [1]:
import os
import torch
import torch.nn as nn
# CUDA for PyTorch: number GPUs by PCI bus order and expose only GPU 0 to this
# process, then pick CUDA when it is available and fall back to the CPU otherwise.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "0",
})
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")

In [None]:
# Clone the official StyleGAN2-ADA (PyTorch) repo and work from inside it;
# later cells use paths relative to this directory.
!git clone https://github.com/NVlabs/stylegan2-ada-pytorch.git
%cd stylegan2-ada-pytorch

# Extra packages (versions are not pinned).
!pip install ninja imageio-ffmpeg

!pip install requests

# FFHQ 256x256 StyleGAN2 checkpoint from the stylegan2-ada-pytorch
# transfer-learning source nets, saved as ffhq256.pkl.
!wget -q https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res256-mirror-stylegan2-noaug.pkl \
      -O ffhq256.pkl

# NOTE(review): this second download writes to the same ffhq256.pkl, so it
# overwrites the checkpoint fetched above; the NGC file is what later cells load.
!wget --content-disposition \
  "https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-ffhq-256x256.pkl" \
  -O ffhq256.pkl

# Confirm the checkpoint exists and check its size.
!ls -lh ffhq256.pkl

In [None]:
import pickle
import torch

# Keep the CPU fallback instead of hard-coding 'cuda', which fails on machines
# without a visible GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the EMA generator from the network pickle downloaded above.
# NOTE: pickle.load executes arbitrary code from the file — only load
# checkpoints from a trusted source.
# NOTE(review): unpickling presumably needs the cloned repo's packages to be
# importable (the notebook %cd's into it first) — confirm.
with open('ffhq256.pkl', 'rb') as f:
    G = pickle.load(f)['G_ema'].to(device)

In [None]:
# Sample one random latent and render it. This is inference only, so skip
# autograd; otherwise z/w/img would keep a graph (and its activations) alive.
with torch.no_grad():
    z = torch.randn(1, G.z_dim).to(device)
    w = G.mapping(z, None)  # latent space (can be w or w+); c=None: unconditional

    # noise_mode in ['const', 'random', 'none']
    img = G.synthesis(w, noise_mode='const')  # [1, 3, 256, 256]

In [None]:
# CelebA-HQ 256x256 Download

import os
import zipfile
import shutil

import subprocess
import sys

# 1. kaggle.json file: copy the API credentials to where the Kaggle CLI reads
#    them. The source path can be overridden with the KAGGLE_JSON_PATH env var.
kaggle_json_path = os.environ.get("KAGGLE_JSON_PATH",
                                  "/home/user_yjs/DGM_Project/kaggle.json")
kaggle_dir = os.path.expanduser("~/.kaggle")
os.makedirs(kaggle_dir, exist_ok=True)
shutil.copy(kaggle_json_path, os.path.join(kaggle_dir, "kaggle.json"))
os.chmod(os.path.join(kaggle_dir, "kaggle.json"), 0o600)  # owner read/write only

# 2. import kaggle library; install into *this* kernel's interpreter if missing
#    (a bare "pip" may belong to a different Python).
try:
    import kaggle
except ImportError:
    subprocess.check_call([sys.executable, "-m", "pip", "install", "kaggle"])
    import kaggle

# 3. Download & Unzip
dataset_name = "badasstechie/celebahq-resized-256x256"
zip_name = "celebahq-resized-256x256.zip"
output_dir = "celebahq_256"

# Download. check=True raises immediately if the CLI fails. os.system ignored
# the exit code, so a failure only surfaced later as a missing-zip error.
if not os.path.exists(zip_name):
    subprocess.run(["kaggle", "datasets", "download", "-d", dataset_name], check=True)

# Unzip
if not os.path.exists(output_dir):
    with zipfile.ZipFile(zip_name, "r") as zip_ref:
        zip_ref.extractall(output_dir)

print(f"Download and unzip completed: {output_dir}")

In [None]:
from PIL import Image
import torch
import torchvision.transforms as T
import matplotlib.pyplot as plt
import torch.nn.functional as F

# Image path, relative to the cloned repo directory the notebook %cd's into.
# It matches the dataset root used by InpaintingDataset below, instead of a
# user-specific absolute path.
img_path = "celebahq_256/celeba_hq_256/00013.jpg"

# Image loading and preprocessing: resize and convert to a float tensor in [0, 1].
transform = T.Compose([
    T.Resize((256, 256)),
    T.ToTensor(),  # [0, 1]
])
img = Image.open(img_path).convert("RGB")
img_tensor = transform(img)  # (3, 256, 256)

# Centred 96x96 hole: mask is 1 inside the hole and 0 elsewhere.
mask = torch.zeros(1, 256, 256)
mask[:, 80:176, 80:176] = 1.0
masked_img_tensor = img_tensor * (1 - mask)

# Encoder input: masked image + mask
encoder_input = torch.cat([masked_img_tensor, mask], dim=0).unsqueeze(0)  # shape: (1, 4, 256, 256)

# Visualization: original, mask and masked image side by side.
panels = [
    (img_tensor.permute(1, 2, 0).numpy(), {}, "Original"),
    (mask[0], {"cmap": "gray"}, "Mask"),
    (masked_img_tensor.permute(1, 2, 0).numpy(), {}, "Masked Image"),
]
plt.figure(figsize=(12, 4))
for pos, (data, kwargs, title) in enumerate(panels, start=1):
    plt.subplot(1, 3, pos)
    plt.imshow(data, **kwargs)
    plt.title(title)

plt.show()

In [None]:
import torch
import torch.nn as nn

# Conv and downsample block
class ConvBlock(nn.Module):
    """Two (3x3 conv -> GroupNorm(32 groups) -> ReLU) stages; spatial size is kept.

    Args:
        in_c:  input channels.
        out_c: output channels (GroupNorm needs this divisible by 32).
    """
    def __init__(self, in_c, out_c):
        super().__init__()
        layers = []
        for stage_in in (in_c, out_c):
            layers.extend([
                nn.Conv2d(stage_in, out_c, 3, 1, 1),
                nn.GroupNorm(32, out_c),
                nn.ReLU(True),
            ])
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)

class DownsampleBlock(nn.Module):
    """Strided 4x4 conv (stride 2, pad 1) -> GroupNorm(32 groups) -> ReLU.

    Halves even spatial sizes (H, W) -> (H/2, W/2) and maps in_c -> out_c channels.
    """
    def __init__(self, in_c, out_c):
        super().__init__()
        conv = nn.Conv2d(in_c, out_c, 4, 2, 1)
        norm = nn.GroupNorm(32, out_c)
        act = nn.ReLU(True)
        self.block = nn.Sequential(conv, norm, act)

    def forward(self, x):
        return self.block(x)

# Encoder
class InvertFillEncoder(nn.Module):
    """Seven-scale convolutional encoder for the masked input.

    res0 keeps the input resolution; each of res1..res6 halves it
    (256 -> 128 -> 64 -> 32 -> 16 -> 8 -> 4 for a 256x256 input). Every scale
    also has a 1x1 "RGB head" projecting its features to 3 channels.

    forward(x) returns a tuple:
        (res6 feature map,
         list of the 7 feature maps res0..res6,
         list of the 7 RGB projections res0..res6)
    """
    def __init__(self, in_channels: int = 4):
        super().__init__()

        # feature extractor
        self.conv1 = ConvBlock(in_channels, 64)   # res0 : 256x256
        self.down1 = DownsampleBlock(64, 128)     # res1 : 128x128
        self.down2 = DownsampleBlock(128, 256)    # res2 :  64x64
        self.down3 = DownsampleBlock(256, 256)    # res3 :  32x32
        self.down4 = DownsampleBlock(256, 512)    # res4 :  16x16
        self.down5 = DownsampleBlock(512, 512)    # res5 :   8x8
        self.down6 = DownsampleBlock(512, 512)    # res6 :   4x4

        # multi-scale RGB heads, one per scale res0..res6
        scale_widths = (64, 128, 256, 256, 512, 512, 512)
        self.rgb_heads = nn.ModuleList(nn.Conv2d(c, 3, 1) for c in scale_widths)

        # NOTE(review): rgb1/rgb2/rgb3 are never used in forward(); they only add
        # unused parameters. Kept so that state_dict keys stay unchanged.
        self.rgb1 = nn.Conv2d(128, 3, 1)   # from res1 (128x128)
        self.rgb2 = nn.Conv2d(256, 3, 1)   # from res3 (32x32)
        self.rgb3 = nn.Conv2d(512, 3, 1)   # from res5 (8x8)

    def forward(self, x):
        stages = (self.conv1, self.down1, self.down2, self.down3,
                  self.down4, self.down5, self.down6)
        pyramid = []
        for stage in stages:
            x = stage(x)
            pyramid.append(x)

        # multi-scale RGB projection
        rgb = [head(feat) for head, feat in zip(self.rgb_heads, pyramid)]

        return pyramid[-1], pyramid, rgb

In [None]:
import torch.nn as nn

# Feature >> style vector
class Map2Style(nn.Module):
    """Collapse a feature map into a single style vector.

    3x3 conv to 512 channels -> ReLU -> global average pool -> flatten ->
    linear to `style_dim`. Maps (B, in_channels, H, W) -> (B, style_dim).
    """
    def __init__(self, in_channels: int, style_dim: int = 512):
        super().__init__()
        conv = nn.Conv2d(in_channels, 512, 3, 1, 1)
        act = nn.ReLU(True)
        pool = nn.AdaptiveAvgPool2d(1)   # -> [B, 512, 1, 1]
        flat = nn.Flatten()              # -> [B, 512]
        fc = nn.Linear(512, style_dim)   # -> [B, style_dim]
        self.mapping = nn.Sequential(conv, act, pool, flat, fc)

    def forward(self, x):
        return self.mapping(x)           # [B, style_dim]

# Style Head
class Map2StyleHead(nn.Module):
    """One Map2Style per encoder scale.

    forward(feats) maps feats[i] (the res{i} feature map, i = 0..6) to a style
    vector and returns them coarse-first:
        [coarse2, coarse1, coarse0, middle1, middle0, fine1, fine0]
    i.e. in scale order res6, res5, ..., res0.
    """
    def __init__(self, style_dim: int = 512):
        super().__init__()
        # (attribute name, input channels), finest (res0) to coarsest (res6)
        spec = (
            ("map_fine0",    64),    # res0 256x256
            ("map_fine1",   128),    # res1 128x128
            ("map_middle0", 256),    # res2  64x64
            ("map_middle1", 256),    # res3  32x32
            ("map_coarse0", 512),    # res4  16x16
            ("map_coarse1", 512),    # res5   8x8
            ("map_coarse2", 512),    # res6   4x4
        )
        for name, channels in spec:
            setattr(self, name, Map2Style(channels, style_dim))

    def forward(self, feats):
        fine_to_coarse = [
            self.map_fine0(feats[0]),
            self.map_fine1(feats[1]),
            self.map_middle0(feats[2]),
            self.map_middle1(feats[3]),
            self.map_coarse0(feats[4]),
            self.map_coarse1(feats[5]),
            self.map_coarse2(feats[6]),
        ]
        return fine_to_coarse[::-1]

In [None]:
# RGB proj >> Structure
class Map2Structure(nn.Module):
    """Collapse a (B, in_channels, H, W) RGB projection into a (B, out_dim) vector.

    Same layout as Map2Style: 3x3 conv to 512 -> ReLU -> global average pool ->
    flatten -> linear to `out_dim`.
    """
    def __init__(self, in_channels: int, out_dim: int = 512):
        super().__init__()
        self.encoder = nn.Sequential(*[
            nn.Conv2d(in_channels, 512, 3, 1, 1),
            nn.ReLU(True),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(512, out_dim),
        ])

    def forward(self, x):
        out = self.encoder(x)
        return out                       # [B, out_dim]

# Structure Head
class Map2StructureHead(nn.Module):
    """One Map2Structure per scale, applied to the 3-channel RGB projections.

    forward(images) maps images[i] (the res{i} projection, i = 0..6) to a vector
    and returns them coarse-first:
        [coarse2, coarse1, coarse0, middle1, middle0, fine1, fine0]
    """
    def __init__(self):
        super().__init__()
        # every scale's input is a 3-channel RGB projection
        for name in ("map_fine0",     # res0 256
                     "map_fine1",     # res1 128
                     "map_middle0",   # res2  64
                     "map_middle1",   # res3  32
                     "map_coarse0",   # res4  16
                     "map_coarse1",   # res5   8
                     "map_coarse2"):  # res6   4
            setattr(self, name, Map2Structure(3))

    def forward(self, images):
        heads = (self.map_fine0, self.map_fine1,
                 self.map_middle0, self.map_middle1,
                 self.map_coarse0, self.map_coarse1, self.map_coarse2)
        per_scale = [heads[i](images[i]) for i in range(7)]

        # coarse(4>>16) >> middle(32>>64) >> fine(128>>256)
        return per_scale[::-1]

In [None]:
# Smoke test: push a random 4-channel 256x256 input through the encoder and
# both heads to check that the shapes line up. No training happens here, so
# autograd is disabled.
encoder = InvertFillEncoder()
map2structure = Map2StructureHead()
map2style = Map2StyleHead()
x = torch.randn(1, 4, 256, 256)  # [B=1, C=4 (masked RGB + mask), H=256, W=256]

with torch.no_grad():
    # f_r: res6 (4x4) feature map; feats_for_structure: 7 feature maps res0..res6;
    # Or_E: 7 RGB projections res0..res6
    f_r, feats_for_structure, Or_E = encoder(x)
    # S_r: list of 7 structure vectors [B, 512], ordered coarse2 … fine0
    S_r = map2structure(Or_E)

    # w_primes: list of 7 style vectors [B, 512], ordered coarse2 … fine0
    w_primes = map2style(feats_for_structure)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Modulation
class FullPreModulationNetwork(nn.Module):
    """
    Pre-modulation network: fuses per-level style (w') and structure (S_r)
    vectors into the per-layer latent w* fed to StyleGAN2-ADA synthesis.

    StyleGAN2-ADA 256×256:
        - Resolution stages: 4,8,16,32,64,128,256  → 7
        - num of layers      : 7 × 2 conv           → 14
    input
        w_primes : (B, 7, 512)  coarse2 … fine0
        S_r      : (B, 7, 512)
    output
        w*       : (B, 14, 512)  layers 2r and 2r+1 both come from level r

    Per level r:
        w_norm = LayerNorm(MLP(w'_r))                 (attribute `inst_norm`)
        h      = ReLU(Linear([S_r ‖ w'_r]))
        w*_r   = to_gamma(h) * w_norm + to_beta(h)
    """
    def __init__(self, style_dim: int = 512):
        super().__init__()
        self.num_layers = 14
        self.style_dim  = style_dim

        # step-1  : w'  >> Dense × 2 >> normalization (LayerNorm over the vector)
        self.w_proj  = nn.Sequential(
            nn.Linear(style_dim, style_dim), nn.ReLU(True),
            nn.Linear(style_dim, style_dim), nn.ReLU(True),
        )
        self.inst_norm = nn.LayerNorm(style_dim)

        # step-2  : [S_r ‖ w_r] >> Dense >> gamma, beta
        self.joint_proj = nn.Linear(style_dim * 2, style_dim)
        self.to_gamma   = nn.Linear(style_dim, style_dim)
        self.to_beta    = nn.Linear(style_dim, style_dim)

        # level_map:  [0,0,1,1,2,2,3,3,4,4,5,5,6,6]
        self.level_map = [i for i in range(7) for _ in range(2)]

    # ----------------------------------------------------------
    def _modulate_level(self, w_r: torch.Tensor, s_r: torch.Tensor) -> torch.Tensor:
        """Compute w* for one level from its (B, D) style and structure vectors."""
        # Step-1  : w' >> projection + normalization
        w_norm = self.inst_norm(self.w_proj(w_r))

        # Step-2  : [s_r ‖ w_r] >> gamma, beta
        joint = torch.cat([s_r, w_r], dim=1)       # (B, 2D)
        h     = F.relu(self.joint_proj(joint))     # (B, D)
        gamma = self.to_gamma(h)
        beta  = self.to_beta(h)

        # Step-3  : gamma * norm(w') + beta
        return gamma * w_norm + beta

    def forward(self, w_primes: torch.Tensor, S_r: torch.Tensor):
        """
        Parameters
        ----------
        w_primes : (B, 7, 512)
        S_r      : (B, 7, 512)

        Returns
        -------
        w_star   : (B, 14, 512)
        """
        # Each level feeds two consecutive layers (level_map) and the mapping is
        # deterministic, so evaluate every level once and reuse the result
        # instead of running the same MLPs twice per level.
        per_level = {}
        w_star_layers = []
        for l in range(self.num_layers):
            r = self.level_map[l]       # 0~6
            if r not in per_level:
                per_level[r] = self._modulate_level(w_primes[:, r, :], S_r[:, r, :])
            w_star_layers.append(per_level[r])

        return torch.stack(w_star_layers, dim=1)     # (B, 14, 512)

In [None]:
# randomly input
# Random stand-ins for the 7 style vectors (coarse2 … fine0) and the 7
# structure vectors; drawn in the same order as before (w_primes first).
w_primes, S_r = torch.randn(1, 7, 512), torch.randn(1, 7, 512)

# The network has no dropout or batch-norm layers, so eval() does not change
# its output here; it is kept as a safeguard.
full_premod = FullPreModulationNetwork().eval()

w_star_all = full_premod(w_primes, S_r)      # (1, 14, 512)
print("w_star_all shape:", w_star_all.shape)

In [None]:
# Pre-modulation network implementation
# Re-run the pre-modulation network (StyleGAN2-ADA 256: 14 w layers, not 18).
w_star_all = full_premod(w_primes, S_r)  # (1, 14, 512)

# output check
print("w_star_all shape:", w_star_all.shape)

# Print the first 5 components of each per-layer w* vector.
for l in range(w_star_all.shape[1]):
    print(f"w*_{l}:", w_star_all[0, l, :5])

In [None]:
# Selected sample
# Render the modulated latent with the pretrained generator. This is inference
# only, so no autograd graph is built. It uses the configured `device` instead
# of a hard-coded 'cuda'.
with torch.no_grad():
    w_star_sample = w_star_all.to(device)

    # Generate image
    img = G.synthesis(w_star_sample, noise_mode='const')  # (1, 3, H, W)

In [24]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'  # CPU fallback instead of failing on .to('cuda')

In [None]:
import torch, torch.nn as nn, torch.nn.functional as F
from torchvision.models import vgg16, VGG16_Weights

# VGG16 ─ perceptual·style backbone
class SharedVGG(nn.Module):
    """
    Frozen VGG16 feature extractor shared by the perceptual and style losses.

    Used layers
      relu1_2, relu2_2, relu3_3
    output: List[Tensor]  length 3 (activations after each layer, in order)

    Args:
        normalize: if True, inputs are treated as RGB in [0, 1] and standardized
            with the ImageNet mean/std that the torchvision VGG16 weights expect.
            Defaults to False, which feeds inputs unchanged (previous behaviour).
    """
    def __init__(self, normalize: bool = False):
        super().__init__()
        vgg = vgg16(weights=VGG16_Weights.DEFAULT).features
        self.layers = nn.ModuleList([
            vgg[:4],     # relu1_2
            vgg[4:9],    # relu2_2
            vgg[9:16],   # relu3_3
        ])
        for p in self.parameters():
            p.requires_grad = False

        self.normalize = normalize
        # Non-persistent buffers: they follow .to(device) but do not add
        # state_dict keys.
        self.register_buffer(
            "mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1), persistent=False)
        self.register_buffer(
            "std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1), persistent=False)

    def forward(self, x):
        if self.normalize:
            x = (x - self.mean) / self.std
        feats = []
        for layer in self.layers:
            x = layer(x)
            feats.append(x)
        return feats                         # 3-scale feature

# 2. Perceptual  &  Style  Loss
class PerceptualLoss(nn.Module):
    """Sum over the shared VGG scales of the MSE between input and target features."""
    def __init__(self, shared_vgg):
        super().__init__()
        self.vgg = shared_vgg

    def forward(self, inp, tgt):
        feats_inp = self.vgg(inp)
        feats_tgt = self.vgg(tgt)
        per_scale = (F.mse_loss(a, b) for a, b in zip(feats_inp, feats_tgt))
        return sum(per_scale, 0.)


class StyleLoss(nn.Module):
    """Sum over the shared VGG scales of the MSE between feature Gram matrices."""
    def __init__(self, shared_vgg):
        super().__init__()
        self.vgg = shared_vgg

    @staticmethod
    def gram(x):
        """(B, C, H, W) -> (B, C, C) Gram matrix, normalized by C*H*W."""
        B, C, H, W = x.shape
        flat = x.view(B, C, H * W)
        return flat @ flat.transpose(1, 2) / (C * H * W)

    def forward(self, inp, tgt):
        total = 0.
        for feat_inp, feat_tgt in zip(self.vgg(inp), self.vgg(tgt)):
            total = total + F.mse_loss(self.gram(feat_inp), self.gram(feat_tgt))
        return total

# Loss of pixel, TV
def reconstruction_loss_valid(I_pred, I_gt, mask):
    """MSE over the known (mask == 0) pixels.

    Hole pixels are zeroed in both tensors: they add 0 to the sum but still
    count in the mean's denominator.
    """
    keep = 1 - mask
    return F.mse_loss(I_pred * keep, I_gt * keep)

def reconstruction_loss_hole(I_pred, I_gt, mask):
    """MSE over the hole (mask == 1) pixels; other pixels are zeroed in both inputs."""
    pred_hole = I_pred * mask
    gt_hole = I_gt * mask
    return F.mse_loss(pred_hole, gt_hole)

def tv_loss(I_pred):
    """Total variation of a (B, C, H, W) batch: mean |horizontal diff| + mean |vertical diff|."""
    horizontal = (I_pred[:, :, :, :-1] - I_pred[:, :, :, 1:]).abs().mean()
    vertical = (I_pred[:, :, :-1, :] - I_pred[:, :, 1:, :]).abs().mean()
    return horizontal + vertical

# 4. MSR (multi-scale recon) loss
def mse_loss(Im, I_gt_encoder, encoder, map2structure):
    """
    Multi-scale reconstruction (MSR) loss.

    Encodes the masked input and the ground-truth encoder input, then sums the
    per-scale MSE between their RGB projections (Or_E) and between their
    structure vectors (S_r).

    encoder        : InvertFillEncoder (7-scale)
    map2structure  : Map2StructureHead (7-scale)
    """
    def _project(x):
        # RGB projections (list len == 7) and their structure vectors
        _, _, rgb = encoder(x)
        return rgb, map2structure(rgb)

    rgb_masked, struct_masked = _project(Im)
    rgb_gt, struct_gt = _project(I_gt_encoder)

    rgb_term = sum(F.mse_loss(p, q) for p, q in zip(rgb_masked, rgb_gt))
    struct_term = sum(F.mse_loss(p, q) for p, q in zip(struct_masked, struct_gt))
    return rgb_term + struct_term

# 5. loss dict: one frozen VGG16 instance is shared by the perceptual and style terms
shared_vgg = SharedVGG().to(device)

losses = dict(
    valid=reconstruction_loss_valid,
    hole=reconstruction_loss_hole,
    perceptual=PerceptualLoss(shared_vgg),
    style=StyleLoss(shared_vgg),
    tv=tv_loss,
    msr=mse_loss,
)

In [None]:
# Fresh (untrained) inpainting modules. The construction order is unchanged,
# so the random initialisation sequence is the same.
encoder, map2style, map2structure, premod = (
    InvertFillEncoder(),
    Map2StyleHead(),
    Map2StructureHead(),
    FullPreModulationNetwork(),
)

# Second copy of the pretrained EMA generator, used by the trainer.
# NOTE: pickle.load executes code from the file — only load trusted checkpoints.
with open('ffhq256.pkl', 'rb') as f:
    generator = pickle.load(f)['G_ema']
generator = generator.to(device)

In [None]:
class InpaintingTrainer:
    """
    End-to-end inpainting training loop
      Encoder          : 7 scales (res0 ~ res6)
      Map2Style/Struct : 7 vectors (512 dim)
      PreMod           : 14 layers (coarse2 … fine0, conv×2)
      Generator        : StyleGAN2-ADA 256,  num_ws = 14

    Only the encoder, the two mapping heads and the pre-modulation network are
    optimized (Adam, lr=1e-4). The generator is frozen and acts as a fixed
    decoder from w* to images; gradients still flow through it to w*.
    """
    def __init__(self,
                 encoder, map2style, map2structure,
                 premod, generator, losses,
                 device: str = "cuda"):

        self.device        = device
        self.encoder       = encoder.to(device)
        self.map2style     = map2style.to(device)
        self.map2structure = map2structure.to(device)
        self.premod        = premod.to(device)
        # The generator is not in the optimizer below. Without freezing, every
        # backward() would accumulate .grad on its weights, and this optimizer's
        # zero_grad() never clears them.
        self.generator     = generator.to(device).requires_grad_(False)
        self.losses        = losses

        # Number of w vectors the synthesis network expects (fallback: 14).
        self.num_ws = getattr(generator, "num_ws",
                              getattr(generator.synthesis, "num_ws", 14))

        self.optim = torch.optim.Adam(
            list(self.encoder.parameters())       +
            list(self.map2style.parameters())     +
            list(self.map2structure.parameters()) +
            list(self.premod.parameters()),
            lr = 1e-4
        )

    # ----------------------------------------------------------
    def train_step(self, Im, I_gt, I_gt_encoder, mask):
        """
        One optimization step.

        Parameters
        ----------
        Im           : (B, 4, 256, 256)  ─ [masked RGB + mask]
        I_gt         : (B, 3, 256, 256)  ─ ground truth
        I_gt_encoder : (B, 4, 256, 256)  ─ GT (mask=0) for MSR loss
        mask         : (B, 1, 256, 256)  ─ 1 inside the hole

        Returns
        -------
        dict of float loss values (weighted), keyed "loss_total", "loss_valid", ...
        """
        Im           = Im.to(self.device)
        I_gt         = I_gt.to(self.device)
        I_gt_encoder = I_gt_encoder.to(self.device)
        mask         = mask.to(self.device)

        # Encoder >> 7 scales feature & O_r^E
        f_r, feats_for_structure, Or_E = self.encoder(Im)               # len 7

        # Extracting Structure / Style
        S_r_list       = self.map2structure(Or_E)                       # 7×[B,512]
        w_primes_list  = self.map2style(feats_for_structure)            # 7×[B,512]

        S_r      = torch.stack(S_r_list,      dim=1)   # (B, 7, 512)
        w_primes = torch.stack(w_primes_list, dim=1)   # (B, 7, 512)

        # Pre-Mod >> w*  (14 layers)
        w_star = self.premod(w_primes, S_r)            # (B, 14, 512)

        # Pad with zero vectors / truncate so w* has exactly num_ws layers.
        if w_star.shape[1] < self.num_ws:                       # pad
            pad = torch.zeros(w_star.size(0),
                               self.num_ws - w_star.size(1),
                               w_star.size(2),
                               device = w_star.device)
            w_star = torch.cat([w_star, pad], dim=1)
        elif w_star.shape[1] > self.num_ws:                     # truncate
            w_star = w_star[:, :self.num_ws]

        # Synthesis
        I_pred = self.generator.synthesis(w_star, noise_mode="const")   # (B,3,256,256)
        I_pred = (I_pred + 1) / 2.0                                     # [-1,1] → [0,1]

        # Losses. Pixel losses are taken against the ground truth I_gt.
        # Im[:, :3] is the *masked* image (zero inside the hole), so using it as
        # the hole target trained the model to fill holes with black.
        L_valid  = 10.0  * self.losses["valid"](I_pred, I_gt, mask)
        L_hole   =         self.losses["hole"] (I_pred, I_gt, mask)
        L_perc   =         self.losses["perceptual"](I_pred, I_gt)
        L_style  = 1e6   * self.losses["style"](I_pred, I_gt)
        L_tv     =         self.losses["tv"](I_pred)
        L_msr    = 1e3   * self.losses["msr"](Im, I_gt_encoder,
                                              self.encoder, self.map2structure)

        L_total = L_valid + L_hole + L_perc + L_style + L_tv + L_msr

        # Back-propagation
        self.optim.zero_grad(set_to_none=True)
        L_total.backward()
        self.optim.step()

        return {
            "loss_total" : L_total.item(),
            "loss_valid" : L_valid.item(),
            "loss_hole"  : L_hole.item(),
            "loss_perc"  : L_perc.item(),
            "loss_style" : L_style.item(),
            "loss_tv"    : L_tv.item(),
            "loss_msr"   : L_msr.item(),
        }

In [28]:
# Wire the modules into the trainer on the notebook's configured device rather
# than a second hard-coded 'cuda'.
trainer = InpaintingTrainer(
    encoder=encoder,
    map2style=map2style,
    map2structure=map2structure,
    premod=premod,
    generator=generator,
    losses=losses,
    device=device,
)

In [None]:
from tqdm import trange

def train_epochs(trainer, data_list, num_epochs=5000):
    """Run `num_epochs` passes of `trainer.train_step` over `data_list`.

    Args:
        trainer: object with train_step(Im, I_gt, I_gt_encoder, mask) returning
            a dict of float losses keyed like the `logs` dict below.
        data_list: re-iterable collection of (Im, I_gt, I_gt_encoder, mask)
            batches that supports len() (e.g. a DataLoader or a list).
        num_epochs: number of passes over `data_list`.

    Returns:
        dict mapping each loss name to a list of per-epoch mean values.

    Raises:
        ValueError: if `data_list` is empty. The per-epoch mean would otherwise
            fail with ZeroDivisionError after the first epoch.
    """
    n_batches = len(data_list)
    if n_batches == 0:
        raise ValueError("train_epochs: data_list is empty")

    logs = {
        'loss_total': [], 'loss_valid': [], 'loss_hole': [],
        'loss_perc': [], 'loss_style': [], 'loss_tv': [], 'loss_msr': []
    }

    # tqdm progress bar instead of a print per epoch
    pbar = trange(num_epochs, desc="Training", unit="epoch")

    for epoch in pbar:
        epoch_log = dict.fromkeys(logs, 0.0)
        for (Im, I_gt, I_gt_encoder, mask) in data_list:
            log = trainer.train_step(Im, I_gt, I_gt_encoder, mask)
            for k, v in log.items():
                if v is not None:
                    epoch_log[k] += v

        for k in logs:
            logs[k].append(epoch_log[k] / n_batches)

        # log update
        pbar.set_postfix({
            'Total': f"{logs['loss_total'][-1]:.4f}",
            'Recon': f"{(logs['loss_valid'][-1]+logs['loss_hole'][-1]):.4f}",
            'Perc': f"{logs['loss_perc'][-1]:.4f}",
            'Style': f"{logs['loss_style'][-1]:.2e}",
            'TV': f"{logs['loss_tv'][-1]:.4f}",
            'MSR': f"{logs['loss_msr'][-1]:.4f}"
        })

    return logs

def plot_loss_curves(logs):
    """Plot every loss history in `logs` (name -> per-epoch values) on one set of axes."""
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots(figsize=(10, 6))
    for name, history in logs.items():
        ax.plot(history, label=name)
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.set_title("Loss Curves")
    ax.legend()
    ax.grid(True)
    plt.show()

def visualize_prediction(I_pred, I_gt, I_masked):
    """Show masked input, prediction and ground truth in a 1x3 row.

    Args:
        I_pred:   generator output; shifted from [-1, 1] to [0, 1] for display.
        I_gt:     ground-truth image, shown clamped to [0, 1].
        I_masked: masked input image, shown clamped to [0, 1].
    Each tensor is squeezed at dim 0 and permuted CHW -> HWC, so a (1, 3, H, W)
    batch of one is assumed — TODO confirm for larger batches.
    """
    import matplotlib.pyplot as plt

    def to_np(img):
        return img.squeeze(0).permute(1, 2, 0).clamp(0, 1).cpu().detach().numpy()

    def to_np_m(img):
        # map [-1, 1] -> [0, 1], then reuse the plain conversion
        return to_np((img.detach().cpu() + 1) / 2)

    panels = (
        (to_np(I_masked), "Masked Input"),
        (to_np_m(I_pred), "Predicted Image"),
        (to_np(I_gt), "Ground Truth"),
    )
    plt.figure(figsize=(12, 4))
    for pos, (arr, title) in enumerate(panels, start=1):
        plt.subplot(1, 3, pos)
        plt.imshow(arr)
        plt.title(title)
        plt.axis("off")

    plt.show()

In [None]:
import os
from PIL import Image

import torch
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as T


# InpaintingDataset
class InpaintingDataset(Dataset):
    """Images from a flat directory, each paired with a generated hole mask.

    Only files ending in .jpg / .png (case-insensitive) are used, in sorted
    path order. Each item is (Im, img, I_gt_enc, mask):
        img      : transform(RGB image), e.g. (3, 256, 256)
        mask     : mask_generator(img.shape[1:], mask_size), 1 = hole
        Im       : img * (1 - mask) stacked with mask        -> 4 channels
        I_gt_enc : img stacked with an all-zero mask channel -> 4 channels
    """
    def __init__(self, image_dir, mask_generator, transform, mask_size=64):
        image_names = [name for name in os.listdir(image_dir)
                       if name.lower().endswith(('.jpg', '.png'))]
        self.paths = sorted(os.path.join(image_dir, name) for name in image_names)
        self.t      = transform
        self.m_gen  = mask_generator
        self.m_size = mask_size

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        img = self.t(Image.open(self.paths[idx]).convert("RGB"))  # (3, 256, 256)

        mask = self.m_gen(img.shape[1:], self.m_size)            # (1, 256, 256)
        known = 1 - mask
        Im = torch.cat([img * known, mask], 0)                   # (4, 256, 256)
        I_gt_enc = torch.cat([img, torch.zeros_like(mask)], 0)

        return Im, img, I_gt_enc, mask
        
def center_square_mask(shape, size=64):
    """Return a (1, H, W) float mask with a centred `size` x `size` square of ones.

    Args:
        shape: (H, W) of the image.
        size:  side length of the square hole.

    The row offset is derived from H and the column offset from W. Previously
    both came from H, which put the square off-centre for non-square shapes.
    """
    H, W = shape
    m = torch.zeros(1, H, W)
    top = (H - size) // 2
    left = (W - size) // 2
    m[:, top:top + size, left:left + size] = 1.
    return m

In [34]:
from torch.utils.data import random_split, DataLoader
transform = T.Compose([T.Resize((256,256)), T.ToTensor()])

full_ds = InpaintingDataset("celebahq_256/celeba_hq_256",
                            center_square_mask, transform, 64)

# 2% of the images for training, the rest held out for validation.
n_train = int(0.02*len(full_ds))
# Seeded generator so the train/val split is identical across runs.
SPLIT_SEED = 0
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_ds, val_ds = random_split(full_ds, [n_train, len(full_ds)-n_train],
                                generator=split_generator)

train_loader = DataLoader(train_ds, batch_size=8, shuffle=True, num_workers=2)
val_loader   = DataLoader(val_ds, batch_size=1, shuffle=False)


In [None]:
logs = train_epochs(trainer=trainer, data_list=train_loader, num_epochs=3000)  # long run: 3000 epochs

In [None]:
plot_loss_curves(logs=logs)  # all loss histories recorded by train_epochs

In [None]:
def visualize_prediction(I_pred, I_gt, I_masked):
    """Show masked input, prediction and ground truth side by side.

    NOTE(review): this is a verbatim redefinition of visualize_prediction from
    an earlier cell; one of the two definitions can be dropped.

    I_masked / I_gt are displayed clamped to [0, 1]; I_pred is mapped from
    [-1, 1] to [0, 1] first. Dim 0 is squeezed, so a batch of one is assumed.
    """
    import matplotlib.pyplot as plt

    def to_np(img):
        chw = img.squeeze(0)
        hwc = chw.permute(1, 2, 0).clamp(0, 1)
        return hwc.cpu().detach().numpy()

    def to_np_m(img):
        shifted = torch.clamp((img.detach().cpu() + 1) / 2, 0, 1)
        return shifted.squeeze(0).permute(1, 2, 0).numpy()

    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    contents = [
        (to_np(I_masked), "Masked Input"),
        (to_np_m(I_pred), "Predicted Image"),
        (to_np(I_gt), "Ground Truth"),
    ]
    for ax, (arr, title) in zip(axes, contents):
        ax.imshow(arr)
        ax.set_title(title)
        ax.axis("off")

    plt.show()

In [None]:
from torch.utils.data import Subset, DataLoader

# Predictions for the first 10 training samples (one image per batch, no autograd).
train10 = DataLoader(Subset(train_ds, range(10)), batch_size=1)

for i, batch in enumerate(train10):
    Im, I_gt, I_gt_enc, mask = (t.to(device) for t in batch)

    with torch.no_grad():
        f_r, feats, Or_E = trainer.encoder(Im)
        S_r = torch.stack(trainer.map2structure(Or_E), dim=1)   # (1, 7, 512)
        w_p = torch.stack(trainer.map2style(feats), dim=1)      # (1, 7, 512)
        w_star = trainer.premod(w_p, S_r)                       # (1, 14, 512)
        # raw generator output; visualize_prediction maps it from [-1, 1] to [0, 1]
        I_pred = trainer.generator.synthesis(w_star, noise_mode='const')

    # Visualization
    visualize_prediction(I_pred, I_gt, Im[:, :3])


In [None]:
from torch.utils.data import Subset

# Predictions for the first 10 held-out samples (one image per batch, no autograd).
val10 = DataLoader(Subset(val_ds, range(10)), batch_size=1)

for i, batch in enumerate(val10):
    Im, I_gt, I_gt_enc, mask = (t.to(device) for t in batch)

    with torch.no_grad():
        f_r, feats, Or_E = trainer.encoder(Im)
        S_r = torch.stack(trainer.map2structure(Or_E), dim=1)   # (1, 7, 512)
        w_p = torch.stack(trainer.map2style(feats), dim=1)      # (1, 7, 512)
        w_star = trainer.premod(w_p, S_r)                       # (1, 14, 512)
        # Left in [-1, 1]: visualize_prediction rescales the prediction itself.
        I_pred = trainer.generator.synthesis(w_star, noise_mode='const')

    visualize_prediction(I_pred, I_gt, Im[:, :3])