In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from torchvision import datasets, transforms
from torchvision.utils import make_grid
from torch.utils.data import DataLoader, Subset

In [3]:
# Latent layout: z = [z1 | z2 | z3]. Later cells slice z with these widths
# (z[:, :Z1_DIM], z[:, Z1_DIM:Z1_DIM+Z2_DIM], z[:, -Z3_DIM:]), so the three
# parts must tile the whole vector exactly.
Z1_DIM = 32
Z2_DIM = 32
Z3_DIM = 64
# Derived rather than hard-coded so the slices can never overlap or leave a
# gap when one of the part widths is changed.
LATENT_DIM = Z1_DIM + Z2_DIM + Z3_DIM  # 128

# Seed torch's global RNG so weight init and latent sampling are repeatable
# across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(DEVICE)

cuda


In [4]:
def show_images(img_tensor, title=""):
    """Display a batch of images as a single grid figure.

    Args:
        img_tensor: batch of image tensors; values are mapped with
            ``x * 0.5 + 0.5`` before plotting (i.e. [-1, 1] -> [0, 1]).
        title: text drawn above the grid.
    """
    # Undo the [-1, 1] scaling and move the data off the graph / device.
    rescaled = (img_tensor * 0.5 + 0.5).detach().cpu()
    # Tile the batch, eight images per row.
    tiled = make_grid(rescaled, nrow=8)
    # matplotlib wants channels last: (C, H, W) -> (H, W, C).
    tiled_hwc = tiled.permute(1, 2, 0).numpy()

    plt.figure(figsize=(12, 6))
    plt.imshow(tiled_hwc)
    plt.title(title)
    plt.axis('off')
    plt.show()

In [5]:
class Generator(nn.Module):
    """Map a latent vector to a 3-channel 64x64 image with values in [-1, 1].

    A linear layer projects the latent vector to a 128x4x4 feature map, which
    four stride-2 transposed convolutions upsample 4 -> 8 -> 16 -> 32 -> 64.
    The final Tanh bounds the output to [-1, 1].

    Args:
        latent_dim: width of the input latent vector.
    """

    def __init__(self, latent_dim=LATENT_DIM):
        super().__init__()
        # Stem: latent vector -> (128, 4, 4) feature map.
        stages = [
            nn.Linear(latent_dim, 128 * 4 * 4),
            nn.ReLU(True),
            nn.Unflatten(1, (128, 4, 4)),
        ]
        # Three upsampling stages, each doubling the spatial size: 8, 16, 32.
        for c_in, c_out in ((128, 64), (64, 32), (32, 16)):
            stages.extend([
                nn.ConvTranspose2d(c_in, c_out, 4, 2, 1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ])
        # Output stage: 32x32 -> 64x64 RGB, no normalisation before Tanh.
        stages.extend([
            nn.ConvTranspose2d(16, 3, 4, 2, 1),
            nn.Tanh(),
        ])
        self.model = nn.Sequential(*stages)

    def forward(self, z):
        """Return the generated image batch for latent batch ``z``."""
        images = self.model(z)
        return images

class Discriminator(nn.Module):
    """Score a 3-channel 64x64 image with a single unbounded scalar.

    Four stride-2 convolutions shrink the input 64 -> 32 -> 16 -> 8 -> 4,
    then the 128x4x4 feature map is flattened and projected to one value.
    There is no final activation, so the output is a raw score.
    """

    def __init__(self):
        super().__init__()
        # First stage has no BatchNorm.
        blocks = [
            nn.Conv2d(3, 16, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
        ]
        # Remaining downsampling stages: conv -> batch norm -> leaky ReLU.
        for c_in, c_out in ((16, 32), (32, 64), (64, 128)):
            blocks += [
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2),
            ]
        # Head: flatten the 128x4x4 map and reduce it to one score per image.
        blocks += [nn.Flatten(), nn.Linear(128 * 4 * 4, 1)]
        self.model = nn.Sequential(*blocks)

    def forward(self, x):
        """Return a (batch, 1) tensor of scores for image batch ``x``."""
        scores = self.model(x)
        return scores

class Encoder(nn.Module):
    """Two-layer MLP that maps a flattened input to a latent vector.

    Args:
        in_features: number of values per sample after flattening. The
            default of 784 matches a 1x28x28 input; for 3x64x64 images pass
            ``in_features=3 * 64 * 64``.
        latent_dim: width of the output vector.
    """

    def __init__(self, in_features=784, latent_dim=LATENT_DIM):
        super().__init__()
        self.net = nn.Sequential(
            nn.Flatten(),  # (batch, ...) -> (batch, in_features)
            nn.Linear(in_features, 512),
            nn.ReLU(),
            nn.Linear(512, latent_dim)
        )

    def forward(self, x):
        """Return a (batch, latent_dim) encoding of ``x``."""
        return self.net(x)


In [7]:
def gan_loss(real_out, fake_out):
    """Wasserstein-style critic loss: mean score on fakes minus mean on reals.

    Minimising this pushes scores for ``real_out`` up and scores for
    ``fake_out`` down.

    Args:
        real_out: critic scores for real samples.
        fake_out: critic scores for generated samples.

    Returns:
        Scalar tensor ``mean(fake_out) - mean(real_out)``.
    """
    return fake_out.mean() - real_out.mean()

def contrastive_loss(z1, z2, temperature=0.07):
    """Cross-entropy over cosine similarities with matching rows as targets.

    Row ``i`` of ``z1`` is treated as the positive for row ``i`` of ``z2``;
    every other row of ``z2`` acts as a negative.

    Args:
        z1: (batch, dim) embeddings.
        z2: (batch, dim) embeddings paired row-by-row with ``z1``.
        temperature: divisor applied to the similarity matrix.

    Returns:
        Scalar loss tensor.
    """
    # Unit-normalise rows so the dot products below are cosine similarities.
    anchors = F.normalize(z1, dim=1)
    candidates = F.normalize(z2, dim=1)

    similarity = torch.matmul(anchors, candidates.T)
    scaled = similarity / temperature

    # The correct "class" for anchor i is candidate i (the diagonal).
    targets = torch.arange(anchors.size(0), device=anchors.device)
    return F.cross_entropy(scaled, targets)

In [12]:
# Preprocessing: resize and centre-crop to 64x64, convert to a tensor, then
# shift/scale each of the three channels with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Resize(64),
    transforms.CenterCrop(64),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3)
])

# NOTE: download=True pulls the archive from Google Drive and can fail with a
# gdown "Too many users have viewed or downloaded this file" quota error.
# If that happens, fetch the CelebA files manually and place them under
# ../data/celeba before re-running this cell (presumably torchvision then
# skips the download — verify against the installed version).
dataset = datasets.CelebA(root="../data", split="train", target_type="attr", download=True, transform=transform)

# Columns of the attribute table used to pick the two domains.
smiling_idx, bald_idx = 31, 4

# Keep every sample that is smiling OR bald. Build the mask with tensor ops
# instead of looping over every row in Python.
attr_table = dataset.attr
keep_mask = (attr_table[:, smiling_idx] == 1) | (attr_table[:, bald_idx] == 1)
indices = torch.nonzero(keep_mask).flatten().tolist()

subset = Subset(dataset, indices)
dataloader = DataLoader(subset, batch_size=32, shuffle=True, num_workers=2)

FileURLRetrievalError: Failed to retrieve file url:

	Too many users have viewed or downloaded this file recently. Please
	try accessing the file again later. If the file you are trying to
	access is particularly large or is shared with many people, it may
	take up to 24 hours to be able to view or download the file. If you
	still can't access a file after 24 hours, contact your domain
	administrator.

You may still be able to access the file from the browser:

	https://drive.google.com/uc?id=0B7EVK8r0v71pZjFTYXZWM3FlRnM

but Gdown can't. Please check connections and permissions.

In [None]:
# One shared generator and three critics: D1 scores smiling images, D2 bald
# images, and D3 whole batches from the dataloader.
G = Generator(LATENT_DIM).to(DEVICE)
D1 = Discriminator().to(DEVICE)
D2 = Discriminator().to(DEVICE)
D3 = Discriminator().to(DEVICE)

opt_G = torch.optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))
# A single optimizer steps all three critics together.
opt_D = torch.optim.Adam(list(D1.parameters()) + list(D2.parameters()) + list(D3.parameters()), lr=2e-4, betas=(0.5, 0.999))

# NOTE(review): gan_loss is the Wasserstein critic objective, but no weight
# clipping or gradient penalty is applied to D1/D2/D3 here, so the critics
# are unconstrained — confirm whether that is intended.
for epoch in range(1):  # set higher for full training
    # Stay None when every batch of the epoch is skipped, so the report
    # below cannot hit an undefined name.
    loss_d1 = loss_d2 = loss_d3 = loss_g = None

    for real_images, attrs in dataloader:
        real_images = real_images.to(DEVICE)
        attrs = attrs.to(DEVICE)
        x1 = real_images[attrs[:, smiling_idx] == 1]  # smiling
        x2 = real_images[attrs[:, bald_idx] == 1]     # bald
        # Both domains need at least one real sample in the batch.
        if len(x1) == 0 or len(x2) == 0:
            continue

        # One latent batch sized by the smiling subset, split into its parts.
        z = torch.randn(x1.size(0), LATENT_DIM, device=DEVICE)
        z1, z2, z3 = z[:, :Z1_DIM], z[:, Z1_DIM:Z1_DIM+Z2_DIM], z[:, -Z3_DIM:]

        # Zeroing z2 (resp. z1) yields the sample for domain 1 (resp. 2);
        # the full vector yields the sample scored by D3.
        fake_x1 = G(torch.cat([z1, torch.zeros_like(z2), z3], dim=1))
        fake_x2 = G(torch.cat([torch.zeros_like(z1), z2, z3], dim=1))
        fake_xy = G(torch.cat([z1, z2, z3], dim=1))

        # ---- critic step (fakes detached so G receives no gradient) ----
        opt_D.zero_grad()
        loss_d1 = gan_loss(D1(x1[:len(fake_x1)]), D1(fake_x1.detach()))
        loss_d2 = gan_loss(D2(x2[:len(fake_x2)]), D2(fake_x2.detach()))
        loss_d3 = gan_loss(D3(real_images[:len(fake_xy)]), D3(fake_xy.detach()))
        (loss_d1 + loss_d2 + loss_d3).backward()
        opt_D.step()

        # ---- generator step: raise all three critics' scores on fakes ----
        opt_G.zero_grad()
        loss_g = -D1(fake_x1).mean() - D2(fake_x2).mean() - D3(fake_xy).mean()
        loss_g.backward()
        opt_G.step()

    # Reports the losses of the last batch that was actually trained on.
    if loss_g is None:
        print(f"Epoch {epoch}: no batch contained both smiling and bald samples; nothing was trained")
    else:
        print(f"Epoch {epoch}: D Loss = {loss_d1.item() + loss_d2.item() + loss_d3.item():.4f}, G Loss = {loss_g.item():.4f}")


FileURLRetrievalError: Failed to retrieve file url:

	Too many users have viewed or downloaded this file recently. Please
	try accessing the file again later. If the file you are trying to
	access is particularly large or is shared with many people, it may
	take up to 24 hours to be able to view or download the file. If you
	still can't access a file after 24 hours, contact your domain
	administrator.

You may still be able to access the file from the browser:

	https://drive.google.com/uc?id=0B7EVK8r0v71pZjFTYXZWM3FlRnM

but Gdown can't. Please check connections and permissions.

In [None]:

# Sample with the generator in eval mode and autograd disabled: BatchNorm
# then uses its running statistics (and does not update them), and no graph
# is built for tensors that are only going to be plotted.
was_training = G.training
G.eval()
with torch.no_grad():
    z = torch.randn(16, LATENT_DIM, device=DEVICE)
    z1, z2, z3 = z[:, :Z1_DIM], z[:, Z1_DIM:Z1_DIM+Z2_DIM], z[:, -Z3_DIM:]
    # Same latent masking as in training: zero z2 for domain 1, zero z1 for
    # domain 2, keep everything for the combined sample.
    img_x1 = G(torch.cat([z1, torch.zeros_like(z2), z3], dim=1))
    img_x2 = G(torch.cat([torch.zeros_like(z1), z2, z3], dim=1))
    img_xy = G(torch.cat([z1, z2, z3], dim=1))
# Put the generator back in whatever mode it was in before sampling.
G.train(was_training)

show_images(img_x1, title="Generated Domain X1 (e.g., Smiling)")
show_images(img_x2, title="Generated Domain X2 (e.g., Bald)")
show_images(img_xy, title="Generated Intersection X1∩X2")
