In [1]:
# simple_gan_mnist.py
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, utils
from torch.utils.data import DataLoader

# ---------- hyperparameters ----------
batch_size = 128          # samples per training batch
z_dim = 100               # size of the latent (noise) vector
lr = 2e-4                 # learning rate passed to the optimizers
n_epochs = 10             # number of passes over the training set
img_size = 28             # image height and width in pixels
img_channels = 1          # number of image channels
sample_dir = "gan_samples"  # directory for the per-epoch sample grids

# Train on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Create the sample directory up front; an existing one is not an error.
os.makedirs(sample_dir, exist_ok=True)

# ---------- generator model ----------
class Generator(nn.Module):
    """Fully connected generator mapping a latent vector to an image.

    The final ``Tanh`` keeps every output value in [-1, 1].

    Args:
        z_dim: Size of the latent (noise) vector.
        img_channels: Number of channels in the generated image.
        img_size: Height and width of the square generated image.
            Defaults to 28; pass it explicitly for any other resolution.
    """

    def __init__(self, z_dim, img_channels, img_size=28):
        super().__init__()
        # Keep the output geometry on the instance so forward() reshapes to
        # what this generator was built for rather than to module globals.
        self.img_channels = img_channels
        self.img_size = img_size
        self.net = nn.Sequential(
            # z -> 256
            nn.Linear(z_dim, 256),
            nn.LeakyReLU(0.2, inplace=True),
            # 256 -> 512
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2, inplace=True),
            # 512 -> 1024
            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2, inplace=True),
            # 1024 -> channels * img_size * img_size
            nn.Linear(1024, img_channels * img_size * img_size),
            nn.Tanh(),  # squash to [-1, 1]
        )

    def forward(self, z):
        """Return images of shape (batch, img_channels, img_size, img_size).

        Args:
            z: Noise tensor whose first dimension is the batch size and whose
                last dimension is ``z_dim``.
        """
        out = self.net(z)
        return out.view(z.size(0), self.img_channels, self.img_size, self.img_size)

# ---------- discriminator model ----------
class Discriminator(nn.Module):
    """Fully connected discriminator scoring images as real or fake.

    The input is flattened, passed through three LeakyReLU + Dropout hidden
    layers, and squashed by a final ``Sigmoid`` to one value in [0, 1] per
    sample.

    Args:
        img_channels: Number of channels in the input images.
        img_size: Height and width of the square input images.
            Defaults to 28; pass it explicitly for any other resolution.
    """

    def __init__(self, img_channels, img_size=28):
        super().__init__()
        # Width of one flattened image; taken from the constructor arguments
        # so the class does not depend on module-level settings.
        in_features = img_channels * img_size * img_size
        self.net = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            nn.Linear(256, 1),
            nn.Sigmoid(),  # probability that the image is real
        )

    def forward(self, x):
        """Return a (batch, 1) tensor of scores in [0, 1] for images ``x``."""
        return self.net(x)

# ---------- data (MNIST) ----------
# Resize to the configured resolution, convert to a tensor, then shift and
# scale pixel values with mean 0.5 / std 0.5 so they land in [-1, 1].
transform = transforms.Compose([
    transforms.Resize(img_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

train_dataset = datasets.MNIST(root="data", train=True, download=True, transform=transform)

# Page-locked host memory only speeds up host-to-GPU copies; requesting it on a
# CPU-only run buys nothing and makes the DataLoader emit a warning.
dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
    pin_memory=(device.type == "cuda"),
)

# ---------- models, loss and optimizers ----------
# Build the generator first, then the discriminator, and move each to the
# selected device.
G = Generator(z_dim=z_dim, img_channels=img_channels)
G.to(device)
D = Discriminator(img_channels=img_channels)
D.to(device)

# Binary cross-entropy on the discriminator's sigmoid output.
criterion = nn.BCELoss()

# Both networks use Adam with the same learning rate and betas.
adam_settings = {"lr": lr, "betas": (0.5, 0.999)}
opt_G = optim.Adam(G.parameters(), **adam_settings)
opt_D = optim.Adam(D.parameters(), **adam_settings)

# Fixed noise batch reused every epoch so the saved sample grids are comparable.
fixed_noise = torch.randn((64, z_dim), device=device)

# ---------- training ----------
print(f"Начинаем обучение на устройстве: {device}")
for epoch in range(1, n_epochs + 1):
    # Running sums of the per-batch losses, averaged at the end of the epoch.
    epoch_loss_D = 0.0
    epoch_loss_G = 0.0
    for real_imgs, _ in dataloader:
        # non_blocking lets the copy overlap with compute when the loader
        # yields pinned tensors; otherwise it is an ordinary synchronous copy.
        real_imgs = real_imgs.to(device, non_blocking=True)

        # The last batch of an epoch may be smaller than batch_size.
        batch_size_cur = real_imgs.size(0)
        real_labels = torch.ones(batch_size_cur, 1, device=device)
        fake_labels = torch.zeros(batch_size_cur, 1, device=device)

        # --- Discriminator update: maximize log(D(x)) + log(1 - D(G(z)))
        D.zero_grad()

        # real part
        outputs_real = D(real_imgs)
        loss_real = criterion(outputs_real, real_labels)

        # fake part: G is not updated in this step, so run it without
        # recording an autograd graph instead of building one and detaching.
        noise = torch.randn(batch_size_cur, z_dim, device=device)
        with torch.no_grad():
            fake_imgs = G(noise)
        outputs_fake = D(fake_imgs)
        loss_fake = criterion(outputs_fake, fake_labels)

        loss_D = loss_real + loss_fake
        loss_D.backward()
        opt_D.step()

        # --- Generator update: BCE against the "real" labels, i.e. maximize
        # log(D(G(z))) rather than minimize log(1 - D(G(z))).
        G.zero_grad()
        noise = torch.randn(batch_size_cur, z_dim, device=device)
        gen_imgs = G(noise)
        outputs = D(gen_imgs)
        # target: the discriminator should label the fakes as real
        loss_G = criterion(outputs, real_labels)
        # This backward pass also leaves gradients on D's parameters; they are
        # cleared by D.zero_grad() at the start of the next iteration.
        loss_G.backward()
        opt_G.step()

        epoch_loss_D += loss_D.item()
        epoch_loss_G += loss_G.item()

    avg_loss_D = epoch_loss_D / len(dataloader)
    avg_loss_G = epoch_loss_G / len(dataloader)
    print(f"Epoch [{epoch}/{n_epochs}]  Loss_D: {avg_loss_D:.4f}  Loss_G: {avg_loss_G:.4f}")

    # Save a sample grid generated from the fixed noise.
    # NOTE(review): G is never switched to eval() here, so its BatchNorm layers
    # normalize with the statistics of this 64-sample batch — confirm whether
    # inference-mode samples are wanted instead.
    with torch.no_grad():
        samples = G(fixed_noise).cpu()
    # de-normalize from [-1, 1] to [0, 1]
    utils.save_image((samples + 1) / 2, os.path.join(sample_dir, f"epoch_{epoch:03d}.png"), nrow=8)

# save the trained weights
torch.save(G.state_dict(), "generator.pth")
torch.save(D.state_dict(), "discriminator.pth")
print("Обучение завершено, модели и примеры сохранены.")


100%|██████████| 9.91M/9.91M [00:02<00:00, 4.68MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 750kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 3.45MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 4.71MB/s]


Начинаем обучение на устройстве: cpu




Epoch [1/10]  Loss_D: 1.2333  Loss_G: 1.0301
Epoch [2/10]  Loss_D: 1.1835  Loss_G: 1.1048
Epoch [3/10]  Loss_D: 1.1213  Loss_G: 1.2030
Epoch [4/10]  Loss_D: 1.0870  Loss_G: 1.2257
Epoch [5/10]  Loss_D: 1.1209  Loss_G: 1.1569
Epoch [6/10]  Loss_D: 1.1549  Loss_G: 1.0837
Epoch [7/10]  Loss_D: 1.1681  Loss_G: 1.0634
Epoch [8/10]  Loss_D: 1.1889  Loss_G: 1.0303
Epoch [9/10]  Loss_D: 1.2001  Loss_G: 1.0144
Epoch [10/10]  Loss_D: 1.1995  Loss_G: 1.0151
Обучение завершено, модели и примеры сохранены.
