In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import CelebA
import matplotlib.pyplot as plt
import os

# Training device: first CUDA GPU if available, otherwise CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Hyperparameters
SEED = 42  # fixed seed so Restart & Run All reproduces init, shuffling and sampled noise
batch_size = 64
image_size = 64
latent_dim = 100
num_attributes = 40  # number of binary attributes in CelebA
num_workers = 4

torch.manual_seed(SEED)

# Image transforms: resize the shorter side, center-crop to a square, then map
# [0, 1] -> [-1, 1] so real images share the range of the generator's Tanh output.
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Load the CelebA training split (download=True fetches it into ./data if missing).
train_dataset = CelebA(root='./data', split='train', download=True, transform=transform)
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=device.type == "cuda",  # page-locked host memory speeds up copies to the GPU
)

# Sanity check that the data loaded.
print(f"Number of training samples: {len(train_dataset)}")


In [None]:
class Generator(nn.Module):
    def __init__(self, latent_dim, num_attributes, img_size=64, channels=3):
        super(Generator, self).__init__()
        self.img_size = img_size
        self.channels = channels
        self.latent_dim = latent_dim
        self.num_attributes = num_attributes

        self.model = nn.Sequential(
            nn.Linear(latent_dim + num_attributes, 128 * 8 * 4 * 4),
            nn.ReLU(True),
            nn.Unflatten(1, (128 * 8, 4, 4)),
            nn.ConvTranspose2d(128 * 8, 128 * 4, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128 * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(128 * 4, 128 * 2, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128 * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(128 * 2, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, z, a):
        # Concatenate latent vector and attribute vector
        input = torch.cat((z, a), 1)
        img = self.model(input)
        return img


In [None]:
class Discriminator(nn.Module):
    def __init__(self, img_size=64, channels=3, num_attributes=40):
        super(Discriminator, self).__init__()
        self.img_size = img_size
        self.channels = channels
        self.num_attributes = num_attributes

        self.model = nn.Sequential(
            nn.Conv2d(channels + num_attributes, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0),
            nn.Sigmoid()
        )

    def forward(self, img, a):
        # Reshape attribute vector to feature map
        a = a.view(a.size(0), a.size(1), 1, 1)
        a = a.repeat(1, 1, img.size(2), img.size(3))
        # Concatenate attribute map with image
        img = torch.cat((img, a), 1)
        validity = self.model(img)
        return validity


In [None]:
# Build both networks on the training device. Construction order (generator first)
# is unchanged, so parameter initialisation consumes the RNG in the same sequence.
generator = Generator(latent_dim=latent_dim, num_attributes=num_attributes).to(device)
discriminator = Discriminator(
    img_size=image_size,
    channels=3,
    num_attributes=num_attributes,
).to(device)

# Optimizers: Adam with identical settings (lr=2e-4, betas=(0.5, 0.999)) for both nets.
optimizer_G, optimizer_D = (
    optim.Adam(net.parameters(), lr=0.0002, betas=(0.5, 0.999))
    for net in (generator, discriminator)
)

# Loss: binary cross-entropy on the discriminator's Sigmoid output.
adversarial_loss = nn.BCELoss()

# Helper for saving a grid of generated samples.
def save_images(epoch, path='images', z=None, a=None):
    """Sample the notebook-level ``generator`` and save an 8-column grid to ``{path}/{epoch}.png``.

    Reads the notebook globals ``generator``, ``device``, ``batch_size``, ``latent_dim``
    and ``num_attributes``. The generator's train/eval mode is left untouched.

    Args:
        epoch: used as the output file name.
        path: output directory; created if it does not exist.
        z: optional noise batch of shape (N, latent_dim). Passing the same tensor every
            epoch makes successive grids directly comparable. Defaults to fresh
            standard-normal noise.
        a: optional attribute batch of shape (N, num_attributes) with 0/1 entries.
            Defaults to uniformly random 0/1 attributes.
    """
    os.makedirs(path, exist_ok=True)
    if z is None:
        n = a.size(0) if a is not None else batch_size
        z = torch.randn(n, latent_dim)
    if a is None:
        a = torch.randint(0, 2, (z.size(0), num_attributes))
    # Sampling only: skip building an autograd graph (saves memory mid-training).
    with torch.no_grad():
        gen_imgs = generator(z.to(device), a.float().to(device))
    # The generator ends in Tanh ([-1, 1]); map to [0, 1] exactly once. Do not also pass
    # normalize=True to save_image: that would min-max rescale the grid a second time
    # and distort contrast.
    gen_imgs = (gen_imgs * 0.5 + 0.5).clamp(0, 1)
    torchvision.utils.save_image(gen_imgs, f"{path}/{epoch}.png", nrow=8)

# Training loop for the conditional GAN.
def train_cgan(generator, discriminator, train_loader, optimizer_G, optimizer_D, adversarial_loss,
               num_epochs=50, log_every=100, save_every=1):
    """Train a conditional GAN with alternating discriminator / generator updates.

    Each batch of ``(imgs, attrs)`` from ``train_loader`` is used for one discriminator
    step (real vs. detached fakes) and then one generator step (fakes labelled real).

    Args:
        generator: module called as ``generator(z, attrs)``. Its ``latent_dim`` attribute
            sets the noise size; the notebook-level ``latent_dim`` is used if it is absent.
        discriminator: module called as ``discriminator(imgs, attrs)``. Its output is
            flattened to (N, 1) before the loss, so (N, 1, 1, 1) outputs are accepted.
        train_loader: iterable of ``(imgs, attrs)`` batches; ``len()`` is used for logging.
        optimizer_G, optimizer_D: optimizers for the generator / discriminator.
        adversarial_loss: loss taking (prediction, target), e.g. ``nn.BCELoss()``.
        num_epochs: number of passes over ``train_loader``.
        log_every: print losses every this many batches (0 disables printing).
        save_every: call ``save_images(epoch)`` every this many epochs (0 disables).

    Returns:
        ``(G_losses, D_losses)``: per-epoch mean generator / discriminator losses (floats).

    Raises:
        ValueError: if ``train_loader`` yields no batches in an epoch.
    """
    # Use the device the models actually live on instead of a notebook global.
    run_device = next(generator.parameters()).device
    noise_dim = getattr(generator, "latent_dim", None)
    if noise_dim is None:
        noise_dim = latent_dim  # notebook-level fallback, e.g. for wrapped models

    G_losses = []
    D_losses = []

    for epoch in range(num_epochs):
        # Re-enter training mode each epoch in case sampling code switched modes.
        generator.train()
        discriminator.train()

        # Accumulate on-device to avoid a host sync per batch.
        g_loss_total = torch.zeros((), device=run_device)
        d_loss_total = torch.zeros((), device=run_device)
        n_batches = 0

        for i, (imgs, attrs) in enumerate(train_loader):
            imgs = imgs.to(run_device)
            attrs = attrs.float().to(run_device)
            n = imgs.size(0)  # last batch may be smaller than the loader's batch_size

            # Adversarial ground truths, shaped (N, 1) like the flattened predictions.
            valid = torch.ones(n, 1, device=run_device)
            fake = torch.zeros(n, 1, device=run_device)

            # ---- Discriminator step ----
            optimizer_D.zero_grad()
            z = torch.randn(n, noise_dim, device=run_device)
            gen_imgs = generator(z, attrs)

            # BCELoss requires identical input/target shapes; the discriminator may
            # return (N, 1, 1, 1), so flatten to (N, 1).
            real_loss = adversarial_loss(discriminator(imgs, attrs).view(-1, 1), valid)
            # detach(): the discriminator step must not backprop into the generator.
            fake_loss = adversarial_loss(discriminator(gen_imgs.detach(), attrs).view(-1, 1), fake)
            d_loss = (real_loss + fake_loss) / 2

            d_loss.backward()
            optimizer_D.step()

            # ---- Generator step: try to make D label fakes as real ----
            optimizer_G.zero_grad()
            g_loss = adversarial_loss(discriminator(gen_imgs, attrs).view(-1, 1), valid)

            g_loss.backward()
            optimizer_G.step()

            g_loss_total += g_loss.detach()
            d_loss_total += d_loss.detach()
            n_batches += 1

            if log_every and i % log_every == 0:
                print(f"[Epoch {epoch}/{num_epochs}] [Batch {i}/{len(train_loader)}] [D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]")

        if n_batches == 0:
            raise ValueError("train_loader yielded no batches")

        # Record the epoch-mean losses rather than just the last batch's.
        G_losses.append((g_loss_total / n_batches).item())
        D_losses.append((d_loss_total / n_batches).item())

        if save_every and (epoch + 1) % save_every == 0:
            save_images(epoch)

    return G_losses, D_losses
