<a href="https://colab.research.google.com/github/ViRiver24/Lesson-8/blob/main/Lesson-8%20CGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [10]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import os

# 1. Завантаження та попередня обробка даних
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))  # Нормалізація до [-1, 1]
])

dataset = datasets.CIFAR10(root='./data', download=True, transform=transform)
data_loader = DataLoader(dataset, batch_size=128, shuffle=True)

# 2. Налаштування архітектури Conditional GAN
class Generator(nn.Module):
    def __init__(self, latent_dim, num_classes, img_shape):
        super(Generator, self).__init__()
        self.label_embedding = nn.Embedding(num_classes, num_classes)
        self.model = nn.Sequential(
            nn.Linear(latent_dim + num_classes, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, int(torch.prod(torch.tensor(img_shape)))),
            nn.Tanh()
        )
        self.img_shape = img_shape

    def forward(self, noise, labels):
        gen_input = torch.cat((noise, self.label_embedding(labels)), dim=1)
        img = self.model(gen_input)
        img = img.view(img.size(0), *self.img_shape)
        return img


class Discriminator(nn.Module):
    def __init__(self, num_classes, img_shape):
        super(Discriminator, self).__init__()
        self.label_embedding = nn.Embedding(num_classes, num_classes)
        self.model = nn.Sequential(
            nn.Linear(num_classes + int(torch.prod(torch.tensor(img_shape))), 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.4),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
        self.img_shape = img_shape

    def forward(self, img, labels):
        d_input = torch.cat((img.view(img.size(0), -1), self.label_embedding(labels)), dim=1)
        validity = self.model(d_input)
        return validity

# Гіперпараметри
latent_dim = 100
num_classes = 10
img_shape = (3, 32, 32)
batch_size = 128
lr = 0.0002

# Використовуємо CPU
device = torch.device("cpu")

generator = Generator(latent_dim, num_classes, img_shape).to(device)
discriminator = Discriminator(num_classes, img_shape).to(device)

adversarial_loss = nn.BCELoss()

optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

# 4. Навчання
os.makedirs("generated_images", exist_ok=True)

for epoch in range(50):
    for i, (imgs, labels) in enumerate(data_loader):
        batch_size = imgs.size(0)

        # Метки
        valid = torch.ones((batch_size, 1), device=device, requires_grad=False)
        fake = torch.zeros((batch_size, 1), device=device, requires_grad=False)

        real_imgs = imgs.to(device)
        labels = labels.to(device)

        # Навчання генератора
        optimizer_G.zero_grad()
        noise = torch.randn((batch_size, latent_dim), device=device)
        gen_labels = torch.randint(0, num_classes, (batch_size,), device=device)
        gen_imgs = generator(noise, gen_labels)
        g_loss = adversarial_loss(discriminator(gen_imgs, gen_labels), valid)
        g_loss.backward()
        optimizer_G.step()

        # Навчання дискримінатора
        optimizer_D.zero_grad()
        real_loss = adversarial_loss(discriminator(real_imgs, labels), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach(), gen_labels), fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        print(f"[Epoch {epoch}/{50}] [Batch {i}/{len(data_loader)}] [D loss: {d_loss.item()}] [G loss: {g_loss.item()}]")

    save_image(gen_imgs.data[:25], f"generated_images/{epoch}.png", nrow=5, normalize=True)

# 5. Оцінка
print("Навчання завершено!")

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
[Epoch 37/50] [Batch 84/391] [D loss: 0.6383829116821289] [G loss: 0.9375270009040833]
[Epoch 37/50] [Batch 85/391] [D loss: 0.6012446880340576] [G loss: 0.9132484197616577]
[Epoch 37/50] [Batch 86/391] [D loss: 0.6404231786727905] [G loss: 0.848567008972168]
[Epoch 37/50] [Batch 87/391] [D loss: 0.6329219341278076] [G loss: 0.944403350353241]
[Epoch 37/50] [Batch 88/391] [D loss: 0.6304757595062256] [G loss: 0.8968034386634827]
[Epoch 37/50] [Batch 89/391] [D loss: 0.6525897979736328] [G loss: 0.8701943159103394]
[Epoch 37/50] [Batch 90/391] [D loss: 0.6169339418411255] [G loss: 0.90693598985672]
[Epoch 37/50] [Batch 91/391] [D loss: 0.6065130233764648] [G loss: 0.9029984474182129]
[Epoch 37/50] [Batch 92/391] [D loss: 0.6340841054916382] [G loss: 0.8897964358329773]
[Epoch 37/50] [Batch 93/391] [D loss: 0.6277996301651001] [G loss: 0.9249264597892761]
[Epoch 37/50] [Batch 94/391] [D loss: 0.623898446559906] [G loss: 0.9