In [None]:
# Cell 1: Setup and imports
import math
import os
import platform
import random
import warnings
from pathlib import Path

import numpy as np
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.ao.quantization as tq
from torch.utils.data import Dataset, DataLoader

# Seed every RNG the notebook draws from so repeated runs see the same numbers.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# Ask cuDNN for deterministic kernels and switch off its autotuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Use the GPU when one is visible, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Device:", device)


In [None]:
# Cell 2: Data paths and attribute table
# Dataset locations; everything else in the notebook is derived from base_dir.
base_dir = Path(r"D:\rtx")
img_dir = base_dir / "img_align_celeba" / "img_align_celeba"
attr_file = base_dir / "list_attr_celeba.csv"

subset_dir = base_dir / "subset_celeba"
subset_dir.mkdir(parents=True, exist_ok=True)

# Load attributes CSV, keyed by image file name
df = pd.read_csv(attr_file, index_col="image_id")

print("Total images in CelebA:", len(df))

# Select a subset (e.g., first 10,000 rows); adjust as needed
subset_n = 10000
filtered_df = df.head(subset_n).copy()

# Map -1/1 to 0/1
filtered_df = (filtered_df + 1) // 2
filtered_df = filtered_df.astype(np.int64)

print("Subset attribute table shape:", filtered_df.shape)
display(filtered_df.head())

# Copy images to subset folder if not present.
# Images are written through PIL (the file is re-encoded, not byte-copied); the
# `with` block closes the source file handle as soon as the copy is written.
copied = 0
missing = []  # listed in the table but absent from both img_dir and subset_dir
for img_name in filtered_df.index:
    src = img_dir / img_name
    dst = subset_dir / img_name
    if dst.exists():
        continue
    if not src.exists():
        missing.append(img_name)
        continue
    with Image.open(src) as im:
        im.save(dst)
    copied += 1
print(f" Copied {copied} images to {subset_dir}")
if missing:
    # The dataset is built from filtered_df, so these rows would otherwise only
    # surface later as a FileNotFoundError inside the DataLoader.
    print(f" WARNING: {len(missing)} listed images were not found in {img_dir} (first: {missing[0]})")


In [None]:
# Cell 3: Dataset and transforms
from torchvision import transforms

# Attribute names in column order, plus the reverse lookup (name -> column position).
attr_names = filtered_df.columns.tolist()
attr_to_idx = dict(zip(attr_names, range(len(attr_names))))

# Image pipeline: resize to 128x128, centre-crop to 128x128, convert to a tensor,
# then normalise each channel with mean 0.5 / std 0.5.
transform = transforms.Compose(
    [
        transforms.Resize((128, 128)),
        transforms.CenterCrop((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
)


class CelebASubset(Dataset):
    """Map-style dataset pairing each image file with its row of attributes.

    Args:
        img_dir: directory containing the image files.
        attr_df: attribute table whose index holds the image file names; its
            values are stored as a float32 array.
        transform: optional callable applied to the RGB PIL image.
    """

    def __init__(self, img_dir: Path, attr_df: pd.DataFrame, transform=None):
        self.img_dir = img_dir
        self.attr_df = attr_df
        self.transform = transform
        # Sample order follows the row order of the table.
        self.names = [name for name in attr_df.index]
        self.attrs = attr_df.to_numpy().astype(np.float32)

    def __len__(self):
        """Number of samples (one per table row)."""
        return len(self.names)

    def __getitem__(self, idx):
        """Return ``(image, attrs)`` for sample ``idx``; attrs is a float32 tensor."""
        img_path = self.img_dir / self.names[idx]
        img = Image.open(img_path).convert("RGB")
        if self.transform:
            img = self.transform(img)
        attr_row = torch.from_numpy(self.attrs[idx])
        return img, attr_row

dataset = CelebASubset(subset_dir, filtered_df, transform)
batch_size = 64
# On Windows, DataLoader workers are separate spawned processes that must re-import
# the dataset class; a class defined in a notebook cell is not importable there, which
# typically makes worker start-up fail or hang. Load in the main process on Windows.
num_workers = 0 if os.name == "nt" else 2
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=(device == "cuda"),  # pinned host memory only helps host->GPU copies
)
len(dataset), len(dataloader)


In [None]:
# Cell 4: Self-attention and fuseable conv blocks
class SelfAttention(nn.Module):
    """Self-attention over all spatial positions with a learnable residual gate.

    Query/key/value are 1x1 convolutions; the attended features are scaled by
    ``gamma`` (initialised to 0, so the layer starts as the identity) and added
    back to the input. The attention matrix is (H*W) x (H*W) per sample, so
    memory grows quadratically with the number of pixels.

    Quantization: the batched matmuls, softmax and the ``gamma * out + x``
    residual are not available as quantized kernels, so the whole layer is kept
    in float. ``dequant``/``quant`` are identity stubs in a float model; after
    ``prepare_qat``/``convert`` they turn the incoming quantized tensor into
    float and re-quantize the result, and the three projections are pinned to
    ``qconfig = None`` so they are never swapped for quantized convolutions.
    The stubs hold no parameters, so float checkpoints keep the same keys.

    Args:
        in_ch: number of input (and output) channels.
    """

    def __init__(self, in_ch):
        super().__init__()
        # Query/key use a reduced width; never let it drop to zero channels.
        qk_ch = max(in_ch // 8, 1)
        self.query = nn.Conv2d(in_ch, qk_ch, 1, bias=False)
        self.key   = nn.Conv2d(in_ch, qk_ch, 1, bias=False)
        self.value = nn.Conv2d(in_ch, in_ch, 1, bias=False)
        self.gamma = nn.Parameter(torch.zeros(1))
        # Float island boundaries for quantized models (identity otherwise).
        self.dequant = tq.DeQuantStub()
        self.quant = tq.QuantStub()
        for proj in (self.query, self.key, self.value):
            proj.qconfig = None  # keep the projections in float

    def forward(self, x):
        """Return a tensor of the same shape as ``x`` (B, C, H, W)."""
        x = self.dequant(x)
        b, c, h, w = x.shape
        n = h * w
        q = self.query(x).reshape(b, -1, n).permute(0, 2, 1)  # b, hw, cq
        k = self.key(x).reshape(b, -1, n)                     # b, cq, hw
        attn = torch.softmax(torch.bmm(q, k), dim=-1)         # b, hw, hw (rows: queries)
        v = self.value(x).reshape(b, c, n)                    # b, c, hw
        out = torch.bmm(v, attn.permute(0, 2, 1)).reshape(b, c, h, w)
        return self.quant(self.gamma * out + x)

def up_block(cin, cout):
    """Build a 2x up-sampling stage: ConvTranspose2d(k=4, s=2, p=1, no bias) -> BatchNorm2d -> ReLU.

    The layers sit at Sequential positions '0', '1', '2'.
    """
    deconv = nn.ConvTranspose2d(cin, cout, kernel_size=4, stride=2, padding=1, bias=False)
    norm = nn.BatchNorm2d(cout)
    act = nn.ReLU(inplace=True)
    return nn.Sequential(deconv, norm, act)

def down_block(cin, cout, bn=True):
    """Build a 2x down-sampling stage: Conv2d(k=4, s=2, p=1) [-> BatchNorm2d] -> LeakyReLU(0.2).

    The convolution carries a bias only when batch norm is disabled.
    """
    conv = nn.Conv2d(cin, cout, kernel_size=4, stride=2, padding=1, bias=not bn)
    act = nn.LeakyReLU(0.2, inplace=True)
    if bn:
        return nn.Sequential(conv, nn.BatchNorm2d(cout), act)
    return nn.Sequential(conv, act)

class AttrEncoder1x1(nn.Module):
    """Embed an attribute vector as a 1x1 feature map via two Conv1x1 -> BN -> ReLU stages.

    Args:
        attr_dim: number of attribute entries per sample.
        out_ch: channel count of the returned embedding.
    """

    def __init__(self, attr_dim=40, out_ch=128):
        super().__init__()
        hidden_ch = 128
        stages = []
        for c_in, c_out in ((attr_dim, hidden_ch), (hidden_ch, out_ch)):
            stages.append(nn.Conv2d(c_in, c_out, 1, bias=False))
            stages.append(nn.BatchNorm2d(c_out))
            stages.append(nn.ReLU(inplace=True))
        self.net = nn.Sequential(*stages)

    def forward(self, attrs):  # attrs: (B, attr_dim)
        """Return the embedding with shape (B, out_ch, 1, 1)."""
        as_map = attrs.unsqueeze(-1)
        as_map = as_map.unsqueeze(-1)  # (B, attr_dim, 1, 1)
        return self.net(as_map)


In [None]:
# Cell 5: Attention AttGAN models (128x128)

class AttnAttGANGenerator(nn.Module):
    """Attribute-conditioned generator: (noise, attributes) -> 128x128 image in [-1, 1].

    The latent vector and an attribute embedding are concatenated at 1x1, mixed by
    a 1x1 conv, tiled to 4x4, then up-sampled five times (with self-attention at
    16x16, 32x32 and 64x64) before a small refinement head with a Tanh output.

    Quantization: ``quant``/``dequant`` mark the region that is quantized. The
    attribute encoder runs *before* ``quant`` and therefore always receives float
    input, so it is pinned to ``qconfig = None``; otherwise ``convert`` would
    replace its convolutions with quantized ones that reject float tensors.

    Args:
        z_dim: length of the latent noise vector.
        attr_dim: number of attribute entries per sample.
        base: channel multiplier of the up-sampling chain.
        img_ch: number of channels of the generated image.
    """

    def __init__(self, z_dim=100, attr_dim=40, base=64, img_ch=3):
        super().__init__()
        self.quant = tq.QuantStub()
        self.dequant = tq.DeQuantStub()
        self.attr = AttrEncoder1x1(attr_dim, out_ch=128)
        # Evaluated on float input ahead of the quant stub -> must stay a float module.
        self.attr.qconfig = None

        # Mix latent+attrs at 1x1
        self.mix = nn.Sequential(
            nn.Conv2d(z_dim + 128, base*8, 1, bias=False),
            nn.BatchNorm2d(base*8),
            nn.ReLU(inplace=True)
        )

        # Upsample chain: 4 -> 8 -> 16 -> 32 -> 64 -> 128
        self.up1 = up_block(base*8, base*4)     # 4->8
        self.up2 = up_block(base*4, base*2)     # 8->16
        self.att16 = SelfAttention(base*2)      # attention at 16x16
        self.up3 = up_block(base*2, base)       # 16->32
        self.att32 = SelfAttention(base)        # attention at 32x32
        self.up4 = up_block(base, base//2)      # 32->64
        self.att64 = SelfAttention(base//2)     # attention at 64x64 (new for 128 target)
        self.up5 = up_block(base//2, base//4)   # 64->128

        # Refine head at 128x128
        self.refine = nn.Sequential(
            nn.Conv2d(base//4, base//4, 3, 1, 1, bias=False),
            nn.BatchNorm2d(base//4),
            nn.ReLU(inplace=True),
            nn.Conv2d(base//4, img_ch, 3, 1, 1),
            nn.Tanh()
        )
        self.z_dim = z_dim

    def forward(self, z, attrs):
        """Generate images.

        Args:
            z: latent noise, shape (B, z_dim).
            attrs: attribute vectors, shape (B, attr_dim).

        Returns:
            Float tensor of shape (B, img_ch, 128, 128) with values in [-1, 1].
        """
        z = z.unsqueeze(-1).unsqueeze(-1)          # (B,z,1,1)
        a = self.attr(attrs)                       # (B,128,1,1), float
        x = torch.cat([z, a], 1)
        x = self.quant(x)                          # quantized region starts here
        x = self.mix(x)
        x = F.interpolate(x, size=(4,4), mode='nearest')  # tile the 1x1 code to 4x4
        x = self.up1(x)
        x = self.up2(x); x = self.att16(x)
        x = self.up3(x); x = self.att32(x)
        x = self.up4(x); x = self.att64(x)
        x = self.up5(x)
        x = self.refine(x)
        x = self.dequant(x)                        # back to float for the caller
        return x


class AttnAttGANDiscriminator(nn.Module):
    """Attribute-conditioned discriminator with an adversarial head and an attribute head.

    The attribute vector is embedded by a 1x1 conv, tiled over the image plane and
    concatenated with the image; five stride-2 stages (self-attention after the
    first) reduce the map before two 4x4 convolutional heads produce logits.

    Quantization: ``quant``/``dequant`` mark the region that is quantized. The
    attribute embedding is computed on float input *before* ``quant``, so it is
    pinned to ``qconfig = None``; otherwise ``convert`` would replace it with
    quantized modules that reject float tensors.

    Args:
        attr_dim: number of attribute entries per sample.
        img_ch: number of image channels.
        base: channel multiplier of the down-sampling chain.
    """

    def __init__(self, attr_dim=40, img_ch=3, base=64):
        super().__init__()
        self.quant = tq.QuantStub()
        self.dequant = tq.DeQuantStub()

        # Attribute encoder (1x1 conv to feature channels)
        self.attr = nn.Sequential(
            nn.Conv2d(attr_dim, base, 1, bias=False),
            nn.BatchNorm2d(base),
            nn.LeakyReLU(0.2, inplace=True),
        )
        # Evaluated on float input ahead of the quant stub -> must stay a float module.
        self.attr.qconfig = None

        # Downsample chain: 128 -> 64 -> 32 -> 16 -> 8 -> 4
        self.conv1 = down_block(img_ch + base, base, bn=False)  # 128->64 (no BN)
        self.att64 = SelfAttention(base)                         # attention at 64x64
        self.conv2 = down_block(base, base*2)                    # 64->32
        self.conv3 = down_block(base*2, base*4)                  # 32->16
        self.conv4 = down_block(base*4, base*8)                  # 16->8
        self.conv5 = down_block(base*8, base*8)                  # 8->4 (keep channels)

        # Heads (logits)
        self.adv_head = nn.Conv2d(base*8, 1, 4, 1, 0)
        self.attr_head = nn.Conv2d(base*8, attr_dim, 4, 1, 0)

    def forward(self, img, attrs):
        """Score images.

        Args:
            img: images, shape (B, img_ch, H, W); the heads reduce to one value
                per sample when H = W = 128.
            attrs: attribute vectors, shape (B, attr_dim).

        Returns:
            ``(adv, attr_logits)``: real/fake logits of shape (B, 1) and
            attribute logits of shape (B, attr_dim) for 128x128 input.
        """
        a = self.attr(attrs.unsqueeze(-1).unsqueeze(-1))   # (B, base, 1, 1), float
        a = a.expand(-1, -1, img.size(2), img.size(3))  # tile to HxW
        x = torch.cat([img, a], 1)

        x = self.quant(x)                                  # quantized region starts here
        x = self.conv1(x); x = self.att64(x)
        x = self.conv2(x); x = self.conv3(x); x = self.conv4(x); x = self.conv5(x)
        adv = self.adv_head(x).view(img.size(0), 1)
        attr_logits = self.attr_head(x).view(img.size(0), -1)
        adv = self.dequant(adv); attr_logits = self.dequant(attr_logits)
        return adv, attr_logits


In [None]:
# Cell 6: Fusion utilities for quantization (128x128 models)

def fuse_for_quant_gen(G: AttnAttGANGenerator):
    """Fuse the generator's fusable layer groups in place, ahead of quantization.

    Call it in the mode the model will be prepared in: training mode before
    ``prepare_qat`` (QAT fusion keeps batch norm trainable), eval mode before
    post-training ``prepare`` (batch norm is folded into the convolution).

    * ``mix`` and the first three layers of ``refine`` are Conv2d -> BatchNorm2d
      -> ReLU and are fused as a unit.
    * The ``up*`` blocks start with ConvTranspose2d. ConvTranspose2d + BatchNorm2d
      + ReLU is not a supported fusion pattern; ConvTranspose2d + BatchNorm2d can
      only be folded in eval mode, so in training mode these blocks are left
      unfused.
    * Transposed convolutions do not accept per-channel weight quantization, which
      the default fbgemm configs use, so each one is pinned to a per-tensor
      qconfig. A qconfig set on a module takes precedence over the one inherited
      from the model root.
    """
    qat = G.training
    # fuse_modules is the eval-mode entry point; QAT fusion has its own one
    # (fall back to fuse_modules on PyTorch builds that predate the split).
    fuse = getattr(tq, "fuse_modules_qat", tq.fuse_modules) if qat else tq.fuse_modules

    # Conv2d -> BatchNorm2d -> ReLU groups; the final Conv+Tanh of refine is not fused.
    fuse(G.mix, [['0', '1', '2']], inplace=True)
    fuse(G.refine, [['0', '1', '2']], inplace=True)

    for blk in (G.up1, G.up2, G.up3, G.up4, G.up5):
        if not qat:
            fuse(blk, [['0', '1']], inplace=True)  # fold BN into the transposed conv
        blk[0].qconfig = tq.default_qat_qconfig if qat else tq.default_qconfig

def fuse_for_quant_disc(D: AttnAttGANDiscriminator):
    """Fuse the discriminator's Conv2d + BatchNorm2d pairs in place, ahead of quantization.

    Call it in the mode the model will be prepared in: training mode before
    ``prepare_qat``, eval mode before post-training ``prepare``.

    Every block here ends in LeakyReLU, which is not part of any supported fusion
    pattern, so only the leading Conv2d + BatchNorm2d pair ('0', '1') is fused and
    the activation stays a separate module. ``conv1`` has no batch norm and is
    left untouched.
    """
    # fuse_modules is the eval-mode entry point; QAT fusion has its own one
    # (fall back to fuse_modules on PyTorch builds that predate the split).
    fuse = getattr(tq, "fuse_modules_qat", tq.fuse_modules) if D.training else tq.fuse_modules

    # Attribute encoder: Conv2d -> BatchNorm2d (-> LeakyReLU left as is)
    fuse(D.attr, [['0', '1']], inplace=True)
    # Down blocks that contain a batch norm
    for name in ('conv2', 'conv3', 'conv4', 'conv5'):
        fuse(getattr(D, name), [['0', '1']], inplace=True)


In [None]:
# Cell 7: Instantiate models and optimizers
# Latent size is fixed; the attribute width follows the loaded table.
z_dim = 100
attr_dim = len(attr_names)

# Generator is built first, then the discriminator; both are moved to the chosen device.
G = AttnAttGANGenerator(z_dim=z_dim, attr_dim=attr_dim)
G = G.to(device)
D = AttnAttGANDiscriminator(attr_dim=attr_dim)
D = D.to(device)

# One Adam optimizer per network, same learning rate and betas.
lr = 2e-4
adam_betas = (0.5, 0.999)
optimizerG = optim.Adam(G.parameters(), lr=lr, betas=adam_betas)
optimizerD = optim.Adam(D.parameters(), lr=lr, betas=adam_betas)

g_param_count = sum(p.numel() for p in G.parameters())
d_param_count = sum(p.numel() for p in D.parameters())
print("Generator params:", g_param_count/1e6, "M")
print("Discriminator params:", d_param_count/1e6, "M")


In [None]:
# Cell 8: Losses and one-epoch training function
# Loss functions and weights shared by every training phase.
bce_logits = nn.BCEWithLogitsLoss()
l1 = nn.L1Loss()
lambda_attr = 10.0   # weight of the attribute-classification term in the G loss
lambda_rec = 50.0    # weight of the L1 term in the G loss

def train_one_epoch(G, D, dataloader, optG, optD, device, z_dim=100):
    """Run one pass over ``dataloader``, alternating a D step and a G step per batch.

    D step: real/fake adversarial BCE plus attribute BCE on real images.
    G step: adversarial BCE towards "real", attribute BCE on generated images
    (weighted by ``lambda_attr``) and L1 against the real batch (``lambda_rec``).

    Args:
        G, D: generator and discriminator; both are put in training mode.
        dataloader: iterable of ``(imgs, attrs)`` batches; need not support ``len``.
        optG, optD: optimizers for ``G`` and ``D``.
        device: device the batches and noise are placed on.
        z_dim: length of the latent noise vector.

    Returns:
        ``(mean_g_loss, mean_d_loss)`` over the batches seen, or ``(0.0, 0.0)``
        when the loader yields no batch.
    """
    G.train(); D.train()
    g_running, d_running = 0.0, 0.0
    n_batches = 0
    for imgs, attrs in dataloader:
        imgs = imgs.to(device)
        attrs = attrs.to(device).float()
        bs = imgs.size(0)
        z = torch.randn(bs, z_dim, device=device)
        attrs_tgt = attrs  # could sample/edit attributes here
        real_lbl = torch.ones(bs, 1, device=device)
        fake_lbl = torch.zeros(bs, 1, device=device)

        # 1 Train D: adv real/fake + attribute on real
        optD.zero_grad(set_to_none=True)
        real_adv, real_attr_logits = D(imgs, attrs)
        d_adv_real = bce_logits(real_adv, real_lbl)
        d_attr_real = bce_logits(real_attr_logits, attrs)
        with torch.no_grad():  # D's update never needs G's graph
            fake_imgs = G(z, attrs_tgt)
        fake_adv, _ = D(fake_imgs, attrs_tgt)
        d_adv_fake = bce_logits(fake_adv, fake_lbl)
        d_loss = d_adv_real + d_adv_fake + d_attr_real
        d_loss.backward()
        optD.step()

        # 2 Train G: adv to real + attribute on fake + reconstruction (L1)
        optG.zero_grad(set_to_none=True)
        gen_imgs = G(z, attrs_tgt)
        g_adv, g_attr_logits = D(gen_imgs, attrs_tgt)
        g_adv_loss = bce_logits(g_adv, real_lbl)
        g_attr_loss = bce_logits(g_attr_logits, attrs_tgt)
        # While the target attributes are the source attributes, G(z, attrs) is the
        # same forward pass as gen_imgs: reuse it instead of running G again.
        # A separate pass is only needed once attrs_tgt is actually edited.
        rec_imgs = gen_imgs if attrs_tgt is attrs else G(z, attrs)  # if identity encoder added, use enc(imgs)
        rec_loss = l1(rec_imgs, imgs)
        g_loss = g_adv_loss + lambda_attr * g_attr_loss + lambda_rec * rec_loss
        g_loss.backward()
        optG.step()

        g_running += g_loss.item()
        d_running += d_loss.item()
        n_batches += 1

    denom = max(n_batches, 1)  # an empty loader yields (0.0, 0.0) instead of dividing by zero
    return g_running/denom, d_running/denom


In [None]:
# Cell 9: FP32 training with checkpointing every 10 epochs
ckpt_dir = base_dir / "checkpoints_attgan"
ckpt_dir.mkdir(parents=True, exist_ok=True)

def save_fp32_checkpoint(epoch, G, D, optG, optD, tag="fp32", out_dir=ckpt_dir):
    """Write model and optimizer state for ``epoch`` to ``attgan_{tag}_e{epoch:03d}.pth``.

    Args:
        epoch: epoch number, stored in the file and used in its name.
        G, D: networks whose ``state_dict()`` is saved.
        optG, optD: optimizers whose ``state_dict()`` is saved.
        tag: label embedded in the file name.
        out_dir: target directory. The default is captured when this function is
            defined, so a later cell that rebinds the global ``ckpt_dir`` does
            not redirect these checkpoints.
    """
    state = {
        "epoch": epoch,
        "G": G.state_dict(),
        "D": D.state_dict(),
        "optG": optG.state_dict(),
        "optD": optD.state_dict(),
    }
    path = out_dir / f"attgan_{tag}_e{epoch:03d}.pth"
    torch.save(state, path)
    print(f"Saved FP32 checkpoint: {path}")

# Float training: report the mean losses every epoch, checkpoint every tenth epoch.
total_epochs = 30
for epoch in range(1, total_epochs + 1):
    epoch_losses = train_one_epoch(G, D, dataloader, optimizerG, optimizerD, device, z_dim)
    g_loss, d_loss = epoch_losses
    print(f"FP32 Epoch {epoch}/{total_epochs} | G: {g_loss:.4f} | D: {d_loss:.4f}")
    if epoch % 10 != 0:
        continue
    save_fp32_checkpoint(epoch, G, D, optimizerG, optimizerD, tag="fp32")


In [None]:
# Cell 9.1: Load FP32 checkpoint and generate samples (no saving)
# Adjust checkpoint filename if you want epoch 10 or 20 instead of 30
fp32_ckpt_path = base_dir / "checkpoints_attgan" / "attgan_fp32_e030.pth"
assert fp32_ckpt_path.exists(), f"Missing checkpoint: {fp32_ckpt_path}"

# Build fresh (unfused) float networks and restore the saved weights into them.
G_fp32 = AttnAttGANGenerator(z_dim=z_dim, attr_dim=len(attr_names)).to(device)
D_fp32 = AttnAttGANDiscriminator(attr_dim=len(attr_names)).to(device)
ckpt = torch.load(fp32_ckpt_path, map_location=device)
for net, key in ((G_fp32, "G"), (D_fp32, "D")):
    net.load_state_dict(ckpt[key])
    net.eval()

# Sample a batch with uniformly random 0/1 attributes.
with torch.inference_mode():
    B = 8
    rand_attrs = torch.randint(0, 2, (B, len(attr_names)), device=device, dtype=torch.float32)
    z = torch.randn(B, z_dim, device=device)
    gen_imgs_fp32 = G_fp32(z, rand_attrs)  # [-1,1], shape (B,3,128,128)

# Optional: generate with specific attributes by name
name_to_idx = {n:i for i,n in enumerate(attr_names)}
def make_attr_vec(pairs):
    """Build a (len(attr_names),) attribute vector from ``{name: value}`` pairs.

    Names are matched case-sensitively against ``attr_names``; attributes that
    are not listed stay 0. Names that do not exist are skipped, with a warning
    so that a typo does not silently produce a different face.
    """
    # Use a private lookup: the global ``name_to_idx`` is rebound later in the
    # notebook with lower-cased keys, which would make every lookup here miss.
    index = {n: i for i, n in enumerate(attr_names)}
    v = torch.zeros(len(attr_names), device=device)
    unknown = [k for k in pairs if k not in index]
    if unknown:
        warnings.warn(f"Ignoring unknown attribute name(s): {unknown}")
    for k, val in pairs.items():
        if k in index: v[index[k]] = float(val)
    return v

# Request a fixed attribute combination and sample eight images for it.
pairs = {"Smiling":1, "Blond_Hair":1, "Male":0, "Young":1}
with torch.inference_mode():
    single_vec = make_attr_vec(pairs)
    attr_vec = single_vec.unsqueeze(0).repeat(8, 1)   # same attributes for every sample
    z_sp = torch.randn(8, z_dim, device=device)
    gen_imgs_fp32_spec = G_fp32(z_sp, attr_vec)  # [-1,1]

shape_random = tuple(gen_imgs_fp32.shape)
shape_specific = tuple(gen_imgs_fp32_spec.shape)
print("FP32 gen random:", shape_random, "FP32 gen specific:", shape_specific)


In [None]:
# Cell 10: QAT preparation (fuse + prepare_qat)
# Choose engine: "fbgemm" on x86, "qnnpack" on ARM
# Pick the backend from the CPU architecture (os.name cannot tell x86 from ARM),
# then fall back to the other backend if this PyTorch build lacks the preferred one.
_machine = platform.machine().lower()
_preferred = "qnnpack" if _machine.startswith(("arm", "aarch64")) else "fbgemm"
_fallback = "fbgemm" if _preferred == "qnnpack" else "qnnpack"
_supported = torch.backends.quantized.supported_engines
engine = _preferred if (_preferred in _supported or _fallback not in _supported) else _fallback
torch.backends.quantized.engine = engine
print("Quant engine:", torch.backends.quantized.engine)

# Fusion and prepare_qat both expect training mode.
# NOTE: this cell rewrites G and D in place; re-create them (Cell 7) before running it again.
G.train(); D.train()
fuse_for_quant_gen(G); fuse_for_quant_disc(D)

G.qconfig = tq.get_default_qat_qconfig(engine)
D.qconfig = tq.get_default_qat_qconfig(engine)

tq.prepare_qat(G, inplace=True)
tq.prepare_qat(D, inplace=True)

print("QAT observers inserted.")


In [None]:
# Cell 11: QAT fine-tuning with checkpointing every 10 epochs
# Optionally reduce LR for QAT
# Cap both learning rates at 1e-4 for fine-tuning (an already lower rate is kept).
for opt in (optimizerG, optimizerD):
    for pg in opt.param_groups:
        pg["lr"] = min(pg["lr"], 1e-4)

qat_epochs = 30
for ep in range(1, qat_epochs + 1):
    # Late in the schedule: stop updating quantization ranges from epoch 20 on,
    # and stop updating batch-norm statistics from epoch 25 on.
    if ep == 20:
        for net in (G, D):
            net.apply(tq.disable_observer)
    if ep == 25:
        for net in (G, D):
            net.apply(torch.nn.intrinsic.qat.freeze_bn_stats)

    g_loss, d_loss = train_one_epoch(G, D, dataloader, optimizerG, optimizerD, device, z_dim)
    print(f"QAT Epoch {ep}/{qat_epochs} | G: {g_loss:.4f} | D: {d_loss:.4f}")

    # Checkpoint every tenth epoch only.
    if ep % 10 != 0:
        continue
    state = {
        "epoch": ep,
        "G_qat": G.state_dict(),
        "D_qat": D.state_dict(),
        "optG": optimizerG.state_dict(),
        "optD": optimizerD.state_dict(),
    }
    path = ckpt_dir / f"attgan_qat_e{ep:03d}.pth"
    torch.save(state, path)
    print(f"Saved QAT checkpoint: {path}")


In [None]:
# Cell 12: Convert to INT8 for inference
# The fbgemm/qnnpack quantized kernels run on CPU only, so the models must leave
# the GPU before their layers are swapped for quantized ones.
G.cpu(); D.cpu()
G.eval(); D.eval()
tq.convert(G, inplace=True)
tq.convert(D, inplace=True)
print("Converted to INT8 (fake-quant removed, quantized ops inserted).")


In [None]:
# Cell 13 (updated): Load INT8 generator and generate samples (no saving)
# If you just ran Cell 12 in this session, you can directly use the in-memory G.
# Otherwise, to load from saved INT8 weights, recreate quantized model as below.

int8_G_path = base_dir / "checkpoints_qat" / "G_qat_int8.pt"
int8_D_path = base_dir / "checkpoints_qat" / "D_qat_int8.pt"
if not int8_G_path.exists():
    # The cell that saves these files sits below this one, so on a top-to-bottom
    # run they do not exist yet: write them from the in-memory models converted
    # in Cell 12.
    int8_G_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(G.state_dict(), int8_G_path)
    torch.save(D.state_dict(), int8_D_path)
assert int8_G_path.exists(), f"Missing quantized G: {int8_G_path}"

# Recreate models, fuse, prepare, convert to quantized graph, then load int8 weights.
# Everything stays on CPU: quantized modules cannot run on the GPU.
G_int8 = AttnAttGANGenerator(z_dim=z_dim, attr_dim=len(attr_names))
D_int8 = AttnAttGANDiscriminator(attr_dim=len(attr_names))
G_int8.train(); D_int8.train()  # same mode as during QAT preparation, so the module layout matches
# Reuse the active backend (Cell 10 sets it) instead of hard-coding one.
qengine = torch.backends.quantized.engine
fuse_for_quant_gen(G_int8); fuse_for_quant_disc(D_int8)
G_int8.qconfig = tq.get_default_qat_qconfig(qengine)
D_int8.qconfig = tq.get_default_qat_qconfig(qengine)
tq.prepare_qat(G_int8, inplace=True); tq.prepare_qat(D_int8, inplace=True)
G_int8.eval(); D_int8.eval()
tq.convert(G_int8, inplace=True); tq.convert(D_int8, inplace=True)

# Load INT8 state dicts (optional for D)
G_int8.load_state_dict(torch.load(int8_G_path, map_location="cpu"), strict=True)

# Generate a batch from random attributes (CPU tensors for the CPU-only model)
with torch.inference_mode():
    B = 8
    rand_attrs = torch.randint(0, 2, (B, len(attr_names)), dtype=torch.float32)
    z = torch.randn(B, z_dim)
    gen_imgs_int8 = G_int8(z, rand_attrs)  # [-1,1], shape (B,3,128,128)
print("INT8 gen random:", tuple(gen_imgs_int8.shape), gen_imgs_int8.dtype)


In [None]:
# Cell 14: Save quantized models
# Persist the converted models' state dicts.
# Note: this rebinds the global ckpt_dir, which earlier cells pointed at the training checkpoints.
ckpt_dir = base_dir / "checkpoints_qat"
ckpt_dir.mkdir(exist_ok=True)
for net, fname in ((G, "G_qat_int8.pt"), (D, "D_qat_int8.pt")):
    torch.save(net.state_dict(), ckpt_dir / fname)
print("Saved INT8 checkpoints to:", ckpt_dir)


In [None]:
# Cell 15: Free-form attribute input -> generate (no saving)

import re
import matplotlib.pyplot as plt

def to_display(img_t):
    """Convert a (3, H, W) tensor in [-1, 1] to an (H, W, 3) NumPy array in [0, 1] for imshow.

    Values outside [-1, 1] are clamped first.
    """
    unit_range = img_t.clamp(-1, 1).add(1).div(2.0)
    channels_last = unit_range.permute(1, 2, 0)
    return channels_last.cpu().numpy()

# Build a name->index map once
# Case-insensitive lookup: lower-cased attribute name -> column position.
# This rebinds the name_to_idx defined earlier, whose keys keep their original case.
name_to_idx = {name.lower(): pos for pos, name in enumerate(attr_names)}

def parse_freeform_attrs(s: str, names=None, dev=None) -> torch.Tensor:
    """
    Parse a free-form string like:
    - 'Smiling, Blond_Hair, -Male, Young'
    - '+Smiling +Blond_Hair -Male +Young'
    - 'Smiling Blond_Hair !Male Young'
    Returns a (attr_dim,) float tensor with 0/1 values.

    Tokens are split on commas and whitespace and matched case-insensitively.
    A leading '+' (or no prefix) sets the attribute to 1, a leading '-' or '!'
    sets it to 0. When a name appears several times the last token wins.
    Unspecified attributes default to 0; unknown names are ignored.

    Args:
        s: the free-form attribute string (empty or None gives all zeros).
        names: attribute names in column order; defaults to the notebook's
            ``attr_names``.
        dev: device of the returned tensor; defaults to the notebook's ``device``.
    """
    names = attr_names if names is None else list(names)
    dev = device if dev is None else dev
    index = {n.lower(): i for i, n in enumerate(names)}

    v = torch.zeros(len(names), device=dev)
    if not s or not s.strip():
        return v
    # Split on comma or whitespace
    for tok in re.split(r"[,\s]+", s.strip()):
        if not tok:
            continue
        # Look at the first character only: testing the whole token against
        # "+-!" would never match a prefixed name such as "-Male".
        has_prefix = tok[0] in "+-!"
        negate = tok[0] in "-!"
        raw = tok[1:] if has_prefix else tok
        key = raw.strip().lower()
        if key in index:
            v[index[key]] = 0.0 if negate else 1.0
    return v

# --- Interactive-like usage in a script/notebook ---
# Provide any free-form string here:
user_str = "Smiling, Blond_Hair, -Male, Young"  # edit this string freely

# Run on whatever device the current G lives on: the INT8 model is CPU-only,
# while an FP32 model may sit on the GPU.
g_device = next((p.device for p in G.parameters()), torch.device("cpu"))

# Prepare batch attrs by repeating parsed vector
with torch.inference_mode():
    attr_vec = parse_freeform_attrs(user_str).unsqueeze(0).to(g_device)  # (1, attr_dim)
    B = 8
    attrs_batch = attr_vec.repeat(B, 1)                     # (B, attr_dim)
    z = torch.randn(B, z_dim, device=g_device)

    # Use the current G in memory (INT8 after Cell 12, or FP32 if placed after Cell 9)
    G.eval()
    gen_imgs = G(z, attrs_batch)  # [-1,1], shape (B,3,128,128)

# Visualize first 4 samples inline
plt.figure(figsize=(8,8))
for i in range(4):
    plt.subplot(2,2,i+1)
    plt.imshow(to_display(gen_imgs[i]))
    plt.title(user_str)
    plt.axis("off")
plt.show()

print("Generated shape:", tuple(gen_imgs.shape), "dtype:", gen_imgs.dtype)
