<a href="https://colab.research.google.com/github/ViRiver24/Lesson-8/blob/main/Lesson%208%20DCGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torchvision
import torchvision.transforms as transforms

# Load CIFAR-10 with preprocessing: convert images to tensors, then shift and
# scale every channel with mean 0.5 and std 0.5, i.e. (x - 0.5) / 0.5, which
# maps ToTensor's [0, 1] output onto [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
    ]
)

# Mini-batch size used by the training DataLoader.
batch_size = 64

# Training split; downloaded into ./data if it is not already there.
trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

# Reshuffle the training set on every pass over the loader.
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=True,
)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:01<00:00, 90.8MB/s]


Extracting ./data/cifar-10-python.tar.gz to ./data


In [2]:
import torch.nn as nn

class Generator(nn.Module):
    """DCGAN generator built from four transposed-convolution stages.

    Args:
        latent_dim: Number of input channels expected by the first stage.
        channels_img: Number of channels produced by the last stage.

    For an input of shape ``(N, latent_dim, 1, 1)`` the spatial size grows
    1 -> 4 -> 8 -> 16 -> 32, and the final ``Tanh`` bounds the output to
    [-1, 1].
    """

    @staticmethod
    def _upsample_block(in_channels, out_channels, stride, padding):
        """Return the layers of one ConvTranspose2d -> BatchNorm2d -> ReLU stage."""
        return [
            nn.ConvTranspose2d(
                in_channels, out_channels, kernel_size=4, stride=stride, padding=padding
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]

    def __init__(self, latent_dim, channels_img):
        super().__init__()
        stages = [
            # Stage 1: 1x1 -> 4x4
            *self._upsample_block(latent_dim, 512, stride=1, padding=0),
            # Stage 2: 4x4 -> 8x8
            *self._upsample_block(512, 256, stride=2, padding=1),
            # Stage 3: 8x8 -> 16x16
            *self._upsample_block(256, 128, stride=2, padding=1),
            # Stage 4: 16x16 -> 32x32, no normalisation, Tanh squashes to [-1, 1]
            nn.ConvTranspose2d(128, channels_img, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        ]
        # Layers are unpacked flat so the Sequential indices (and therefore the
        # state_dict keys) stay gen.0 ... gen.10.
        self.gen = nn.Sequential(*stages)

    def forward(self, x):
        """Run the latent tensor ``x`` through the stack and return the result."""
        return self.gen(x)

In [3]:
class Discriminator(nn.Module):
    """DCGAN discriminator built from four strided-convolution stages.

    Args:
        channels_img: Number of channels of the input images.

    For an input of shape ``(N, channels_img, 32, 32)`` the spatial size
    shrinks 32 -> 16 -> 8 -> 4 -> 1, and the final ``Sigmoid`` yields one
    value in [0, 1] per image (the estimated probability of being real).
    """

    @staticmethod
    def _downsample_block(in_channels, out_channels, normalize):
        """Return the layers of one Conv2d -> [BatchNorm2d] -> LeakyReLU(0.2) stage."""
        layers = [nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)]
        if normalize:
            layers.append(nn.BatchNorm2d(out_channels))
        layers.append(nn.LeakyReLU(0.2, inplace=True))
        return layers

    def __init__(self, channels_img):
        super().__init__()
        stages = [
            # Stage 1: 32x32 -> 16x16, no batch norm on the input stage
            *self._downsample_block(channels_img, 128, normalize=False),
            # Stage 2: 16x16 -> 8x8
            *self._downsample_block(128, 256, normalize=True),
            # Stage 3: 8x8 -> 4x4
            *self._downsample_block(256, 512, normalize=True),
            # Stage 4: 4x4 -> 1x1 score, squashed to a probability
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0),
            nn.Sigmoid(),
        ]
        # Layers are unpacked flat so the Sequential indices (and therefore the
        # state_dict keys) stay disc.0 ... disc.9.
        self.disc = nn.Sequential(*stages)

    def forward(self, x):
        """Score the image batch ``x`` and return the per-image outputs."""
        return self.disc(x)

In [4]:
import torch.optim as optim

# Hyperparameters
latent_dim = 100        # channels of the noise tensor fed to the generator
channels_img = 3        # image channels; matches the three-channel normalisation used at load time
learning_rate = 0.0002  # Adam step size shared by both optimizers
beta1 = 0.5             # first Adam momentum coefficient shared by both optimizers

# Reproducibility: seed PyTorch's generators before any weights are created so
# that a fresh "Restart & Run All" starts from the same initial parameters.
seed = 42
torch.manual_seed(seed)


def init_dcgan_weights(module):
    """Apply the DCGAN initialisation scheme to a single sub-module.

    Convolution / transposed-convolution weights are drawn from N(0, 0.02);
    BatchNorm scale factors from N(1, 0.02) with their biases set to zero.
    Convolution biases and all other module types are left untouched.
    Intended to be passed to ``nn.Module.apply``.
    """
    if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(module.weight, mean=0.0, std=0.02)
    elif isinstance(module, nn.BatchNorm2d):
        nn.init.normal_(module.weight, mean=1.0, std=0.02)
        nn.init.zeros_(module.bias)


# Model initialisation: use the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
gen = Generator(latent_dim, channels_img).to(device)
disc = Discriminator(channels_img).to(device)

# Replace PyTorch's default initialisation with the DCGAN scheme.
gen.apply(init_dcgan_weights)
disc.apply(init_dcgan_weights)

# Optimizers: one Adam instance per network, same learning rate and betas.
opt_gen = optim.Adam(gen.parameters(), lr=learning_rate, betas=(beta1, 0.999))
opt_disc = optim.Adam(disc.parameters(), lr=learning_rate, betas=(beta1, 0.999))

# Loss function: binary cross-entropy on the discriminator's sigmoid outputs.
criterion = nn.BCELoss()

In [8]:
import numpy as np
from torchvision.utils import save_image

epochs = 50

# Fixed latent batch, reused at every checkpoint so the saved sample grids are
# directly comparable across epochs. Allocated directly on the target device.
fixed_noise = torch.randn(64, latent_dim, 1, 1, device=device)

for epoch in range(epochs):
    # Per-epoch loss accumulators kept on the device so that accumulating them
    # does not force a host synchronisation on every iteration.
    disc_loss_sum = torch.zeros((), device=device)
    gen_loss_sum = torch.zeros((), device=device)
    num_batches = 0

    for real, _ in trainloader:
        real = real.to(device)
        # Size of *this* batch (the last one may be smaller). Stored under its
        # own name so the notebook-level `batch_size` setting is not overwritten.
        cur_batch_size = real.size(0)

        # Targets: 1 for "real", 0 for "fake".
        real_labels = torch.ones(cur_batch_size, 1, device=device)
        fake_labels = torch.zeros(cur_batch_size, 1, device=device)

        ### Discriminator step ###
        noise = torch.randn(cur_batch_size, latent_dim, 1, 1, device=device)
        fake = gen(noise)
        disc_real = disc(real).view(-1, 1)
        loss_disc_real = criterion(disc_real, real_labels)
        # detach() keeps the discriminator loss from back-propagating into the generator.
        disc_fake = disc(fake.detach()).view(-1, 1)
        loss_disc_fake = criterion(disc_fake, fake_labels)
        loss_disc = (loss_disc_real + loss_disc_fake) / 2

        opt_disc.zero_grad()
        loss_disc.backward()
        opt_disc.step()

        ### Generator step ###
        # Score the same fakes with the just-updated discriminator; the generator
        # is rewarded when they are classified as real.
        output = disc(fake).view(-1, 1)
        loss_gen = criterion(output, real_labels)

        opt_gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        disc_loss_sum += loss_disc.detach()
        gen_loss_sum += loss_gen.detach()
        num_batches += 1

    # One progress line per epoch (mean losses over the epoch's batches).
    denom = max(num_batches, 1)  # guards against an empty loader
    print(
        f"Epoch [{epoch + 1}/{epochs}] "
        f"loss_disc: {disc_loss_sum.item() / denom:.4f} "
        f"loss_gen: {gen_loss_sum.item() / denom:.4f}"
    )

    # Save a sample grid every 10 epochs and also after the final epoch, so the
    # end result of training is always written out.
    if epoch % 10 == 0 or epoch == epochs - 1:
        with torch.no_grad():
            fake = gen(fixed_noise)
            # Undo the [-1, 1] normalisation so pixel values are back in [0, 1].
            save_image(fake * 0.5 + 0.5, f"fake_images_epoch_{epoch}.png")

KeyboardInterrupt: 