In [32]:
import os
from pathlib import Path
from tqdm import tqdm

In [33]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from torchvision.utils import save_image

In [34]:
# Параметры 
IMAGE_SIZE = 64
BATCH = 128
LATENT_DIM = 256
NUM_EPOCHS_GAN = 100
LR = 2e-4
SAMPLE_PER_CLASS = 4
NUM_CLASSES = 5

In [35]:
# Путь к данным
DATA_ROOT = Path("data/animals")
OUT_DIR = Path("outputs")
OUT_DIR.mkdir(parents=True, exist_ok=True)

In [36]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pin_memory = True if torch.cuda.is_available() else False

In [None]:
# Трансформы и DataLoader 
tf = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.CenterCrop(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

dataset = datasets.ImageFolder(root=str(DATA_ROOT), transform=tf)

if len(dataset) == 0:
    raise RuntimeError(f"No images found in {DATA_ROOT}. Проверьте, что данные распакованы в data/animals/<class>/*.jpg")

class_names = dataset.classes
NUM_CLASSES = len(class_names)
print("Found classes:", NUM_CLASSES)

loader = DataLoader(dataset, batch_size=BATCH, shuffle=True, num_workers=4, pin_memory=pin_memory, drop_last=True)

Found classes: 5


In [38]:
# Утилиты 
def save_sample_grid(tensor, path, nrow=8):
    tensor = tensor.clamp(-1, 1)
    save_image((tensor + 1) / 2, path, nrow=nrow)

def weights_init(m):
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        if m.weight is not None:
            nn.init.normal_(m.weight.data, 0.0, 0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.Linear):
        if m.weight is not None:
            nn.init.normal_(m.weight.data, 0.0, 0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.BatchNorm2d):
        if m.weight is not None:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0)

def one_hot(labels, C):
    return torch.zeros(labels.size(0), C, device=labels.device).scatter_(1, labels.unsqueeze(1), 1)

In [None]:
# Conditional GAN модели 
class ConditionalBatchNorm2d(nn.Module):
    def __init__(self, num_features, n_classes, embed_dim=128):
        super().__init__()
        self.bn = nn.BatchNorm2d(num_features, affine=False)
        self.embed = nn.Embedding(n_classes, embed_dim)
        self.fc_gamma = nn.Linear(embed_dim, num_features)
        self.fc_beta = nn.Linear(embed_dim, num_features)
        nn.init.zeros_(self.fc_gamma.weight); nn.init.zeros_(self.fc_gamma.bias)
        nn.init.zeros_(self.fc_beta.weight);  nn.init.zeros_(self.fc_beta.bias)

    def forward(self, x, labels):
        out = self.bn(x)
        e = self.embed(labels)
        gamma = self.fc_gamma(e).unsqueeze(2).unsqueeze(3)
        beta  = self.fc_beta(e).unsqueeze(2).unsqueeze(3)
        return out * (1 + gamma) + beta

from torch.nn.utils import spectral_norm

class CGenerator(nn.Module):
    def __init__(self, z_dim, n_classes, img_channels=3, base=64, embed_dim=128):
        super().__init__()
        self.base = base
        self.label_emb = nn.Embedding(n_classes, embed_dim)

        input_dim = z_dim + embed_dim
        self.fc = nn.Linear(input_dim, base*8*4*4)

        self.deconv1 = nn.ConvTranspose2d(base*8, base*4, 4, 2, 1, bias=False)
        self.cbn1 = ConditionalBatchNorm2d(base*4, n_classes, embed_dim)
        self.deconv2 = nn.ConvTranspose2d(base*4, base*2, 4, 2, 1, bias=False)
        self.cbn2 = ConditionalBatchNorm2d(base*2, n_classes, embed_dim)
        self.deconv3 = nn.ConvTranspose2d(base*2, base, 4, 2, 1, bias=False)
        self.cbn3 = ConditionalBatchNorm2d(base, n_classes, embed_dim)

        self.final = nn.ConvTranspose2d(base, img_channels, 4, 2, 1)

    def forward(self, z, labels):
        l = self.label_emb(labels)
        x = torch.cat([z, l], dim=1)
        x = self.fc(x)
        x = x.view(-1, self.base*8, 4, 4)

        x = F.relu(self.cbn1(self.deconv1(x), labels))
        x = F.relu(self.cbn2(self.deconv2(x), labels))
        x = F.relu(self.cbn3(self.deconv3(x), labels))
        x = torch.tanh(self.final(x))
        return x

class CDiscriminator(nn.Module):
    def __init__(self, n_classes, img_channels=3, base=64):
        super().__init__()
        self.features = nn.Sequential(
            spectral_norm(nn.Conv2d(img_channels, base, 4, 2, 1)),
            nn.LeakyReLU(0.2, inplace=True),
            spectral_norm(nn.Conv2d(base, base*2, 4, 2, 1)),
            nn.LeakyReLU(0.2, inplace=True),
            spectral_norm(nn.Conv2d(base*2, base*4, 4, 2, 1)),
            nn.LeakyReLU(0.2, inplace=True),
            spectral_norm(nn.Conv2d(base*4, base*8, 4, 2, 1)),
            nn.LeakyReLU(0.2, inplace=True),
            nn.AdaptiveAvgPool2d(1)
        )
        self.fc_adv = spectral_norm(nn.Linear(base*8, 1))
        self.fc_cls = spectral_norm(nn.Linear(base*8, n_classes))

    def forward(self, x):
        feat = self.features(x).view(x.size(0), -1)
        adv_out = self.fc_adv(feat).squeeze(1)
        cls_out = self.fc_cls(feat)
        return adv_out, cls_out

In [40]:
def d_hinge_loss(real_logits, fake_logits):
    loss_real = torch.mean(F.relu(1.0 - real_logits))
    loss_fake = torch.mean(F.relu(1.0 + fake_logits))
    return loss_real + loss_fake

def g_hinge_loss(fake_logits):
    return -torch.mean(fake_logits)

def r1_penalty(real_img, real_adv):
    grad_real = torch.autograd.grad(
        outputs=real_adv.sum(),
        inputs=real_img,
        create_graph=True,
        retain_graph=True,
        only_inputs=True
    )[0]
    return (grad_real.view(grad_real.size(0), -1).norm(2, dim=1) ** 2).mean()

def add_instance_noise(x, std=0.05):
    if std > 0:
        noise = torch.randn_like(x) * std
        return x + noise
    return x

In [41]:
#  Инициализация моделей, оптимизаторы 
G = CGenerator(LATENT_DIM, NUM_CLASSES).to(device)
D = CDiscriminator(NUM_CLASSES).to(device)

optG = torch.optim.Adam(G.parameters(), lr=5e-5, betas=(0.5,0.999))
optD = torch.optim.Adam(D.parameters(), lr=2e-4, betas=(0.5,0.999))

cls_loss = nn.CrossEntropyLoss()

fixed_z = torch.randn(NUM_CLASSES * 8, LATENT_DIM, device=device)
fixed_labels = torch.tensor([i for i in range(NUM_CLASSES) for _ in range(8)], device=device)

In [None]:
# Тренировочный цикл
def train_cgan(loader, epochs=100, noise_std=0.05, out_dir=OUT_DIR):
    G.train(); D.train()
    for epoch in range(epochs):
        pbar = tqdm(loader, desc=f"cGAN epoch {epoch+1}/{epochs}")
        for imgs, labs in pbar:
            imgs = imgs.to(device); labs = labs.to(device)
            bs = imgs.size(0)

            # Discriminator step
            for _ in range(2):  # d_steps=2
                z = torch.randn(bs, LATENT_DIM, device=device)
                fake = G(z, labs)
                imgs.requires_grad_(True) 
                real_adv, real_cls = D(add_instance_noise(imgs, noise_std))
                fake_adv, fake_cls = D(add_instance_noise(fake.detach(), noise_std))

                lossD_adv = d_hinge_loss(real_adv, fake_adv)
                lossD_cls = cls_loss(real_cls, labs)
                lossD_r1 = r1_penalty(imgs, real_adv) * 10.0  
                lossD = lossD_adv + lossD_cls + lossD_r1

                optD.zero_grad(); lossD.backward(); optD.step()


            # Generator step
            for _ in range(1):  
                z = torch.randn(bs, LATENT_DIM, device=device)
                fake = G(z, labs)
                fake_adv, fake_cls = D(add_instance_noise(fake, noise_std))
                lossG_adv = g_hinge_loss(fake_adv)
                lossG_cls = cls_loss(fake_cls, labs)
                lossG = lossG_adv + lossG_cls

                optG.zero_grad(); lossG.backward(); optG.step()

            pbar.set_postfix({'lossD': float(lossD.item()), 'lossG': float(lossG.item())})

        G.eval()
        with torch.no_grad():
            # Fixed noise (consistency)
            sample_fixed = G(fixed_z, fixed_labels).cpu()
            save_sample_grid(sample_fixed, f"{out_dir}/cgan_epoch{epoch+1}_fixed.png", nrow=8)

            # Random noise (diversity check)
            z = torch.randn(NUM_CLASSES * 8, LATENT_DIM, device=device)
            labels = torch.tensor([i for i in range(NUM_CLASSES) for _ in range(8)], device=device)
            sample_rand = G(z, labels).cpu()
            save_sample_grid(sample_rand, f"{out_dir}/cgan_epoch{epoch+1}_random.png", nrow=8)
        G.train()

    torch.save(G.state_dict(), f"{out_dir}/cgan_G.pth")
    torch.save(D.state_dict(), f"{out_dir}/cgan_D.pth")
    print("Saved models to", out_dir)

In [46]:
# Генерация по имени класса 
def class_name_to_index(name):
    lowered = [c.lower() for c in class_names]
    name_l = name.lower()
    if name_l in lowered:
        return lowered.index(name_l)
    for i, c in enumerate(lowered):
        if c.startswith(name_l) or name_l in c:
            return i
    raise ValueError("class not found")

def generate_cgan(breed_name, n=4, save_path=None):
    idx = class_name_to_index(breed_name)
    z = torch.randn(n, LATENT_DIM, device=device)
    labels = torch.tensor([idx] * n, device=device)
    G.eval()
    with torch.no_grad():
        imgs = G(z, labels).cpu()
    if save_path:
        save_sample_grid(imgs, save_path, nrow=n)
    return imgs

In [49]:
train_cgan(loader, epochs=NUM_EPOCHS_GAN)

cGAN epoch 1/100: 100%|██████████| 105/105 [00:22<00:00,  4.59it/s, lossD=2.49, lossG=2.2]   
cGAN epoch 2/100: 100%|██████████| 105/105 [00:23<00:00,  4.55it/s, lossD=2.18, lossG=1.34]  
cGAN epoch 3/100: 100%|██████████| 105/105 [00:21<00:00,  4.79it/s, lossD=2.3, lossG=0.543]   
cGAN epoch 4/100: 100%|██████████| 105/105 [00:23<00:00,  4.52it/s, lossD=2.17, lossG=1.21]  
cGAN epoch 5/100: 100%|██████████| 105/105 [00:23<00:00,  4.38it/s, lossD=2.26, lossG=1.2]   
cGAN epoch 6/100: 100%|██████████| 105/105 [00:21<00:00,  4.80it/s, lossD=2.45, lossG=1.41]   
cGAN epoch 7/100: 100%|██████████| 105/105 [00:35<00:00,  2.99it/s, lossD=2.04, lossG=1.3]   
cGAN epoch 8/100: 100%|██████████| 105/105 [00:47<00:00,  2.20it/s, lossD=2.23, lossG=1.39]  
cGAN epoch 9/100: 100%|██████████| 105/105 [00:27<00:00,  3.86it/s, lossD=1.96, lossG=1.04]  
cGAN epoch 10/100: 100%|██████████| 105/105 [00:44<00:00,  2.34it/s, lossD=2.29, lossG=0.144]  
cGAN epoch 11/100: 100%|██████████| 105/105 [00:40<00:00

Saved models to outputs





In [51]:
generate_cgan("dog", n = 1, save_path=OUT_DIR / "generation_dog.png")
generate_cgan("lion", n = 1, save_path=OUT_DIR / "generation_lion.png")
generate_cgan("cat", n = 1, save_path=OUT_DIR / "generation_cat.png")


tensor([[[[-0.2209, -0.2939, -0.3764,  ..., -0.7371, -0.6452, -0.5424],
          [-0.3769, -0.3045, -0.3626,  ..., -0.7341, -0.6815, -0.5002],
          [-0.3659, -0.1786, -0.1979,  ..., -0.4930, -0.4037, -0.4779],
          ...,
          [ 0.0769,  0.4897,  0.5567,  ..., -0.6894, -0.7427, -0.6826],
          [-0.0132,  0.3099,  0.3503,  ..., -0.6484, -0.6385, -0.7146],
          [-0.0366, -0.0117,  0.0284,  ..., -0.7902, -0.6847, -0.5509]],

         [[-0.2958, -0.3286, -0.4163,  ..., -0.7170, -0.6352, -0.5236],
          [-0.4892, -0.3471, -0.3526,  ..., -0.7495, -0.6043, -0.5995],
          [-0.3857, -0.2064, -0.3254,  ..., -0.5928, -0.5159, -0.5104],
          ...,
          [-0.0296,  0.3346,  0.5569,  ..., -0.7462, -0.7754, -0.7273],
          [-0.0726,  0.2254,  0.4042,  ..., -0.7169, -0.7216, -0.7402],
          [-0.0597, -0.1403, -0.0203,  ..., -0.8036, -0.6572, -0.5635]],

         [[-0.3007, -0.5153, -0.4602,  ..., -0.8684, -0.8022, -0.6077],
          [-0.4827, -0.5479, -