In [1]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder

import torch.optim as optim


from torchvision.utils import save_image
from torch.autograd.variable import Variable
from torch.utils.tensorboard import SummaryWriter

In [2]:
# Гиперпараметры
dataset_path = './Skins/'
img_channels = 3
latent_dim = 100
batch_size = 128

transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
])

dataset = ImageFolder(root=dataset_path, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [3]:
# Функция для инициализации весов моделей
def initialize_weights(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1 or classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)

In [4]:
class Generator(nn.Module):
    def __init__(self, latent_dim, img_channels):
        super(Generator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(latent_dim, 256 * 8 * 8),
            nn.BatchNorm1d(256 * 8 * 8),
            nn.ReLU(inplace=True),
            nn.Unflatten(1, (256, 8, 8)),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, img_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, z):
        return self.model(z)
    
class Discriminator(nn.Module):
    def __init__(self, img_channels):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Conv2d(img_channels, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Flatten(),
            nn.Linear(256 * 8 * 8, 1),
            nn.Sigmoid()
        )

    def forward(self, img):
        return self.model(img)


In [5]:
# Инициализация моделей
generator = Generator(latent_dim, img_channels).to(device)
discriminator = Discriminator(img_channels).to(device)

# Применение инициализации весов
generator.apply(initialize_weights)
discriminator.apply(initialize_weights)

# Определение функций потерь и оптимизаторов
adversarial_loss = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=0.0001, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999))

In [7]:
def train_gan(generator, discriminator, optimizer_G, optimizer_D, adversarial_loss, dataloader, device, epochs=100):
    writer = SummaryWriter()

    for epoch in range(epochs):
        for i, (imgs, _) in enumerate(dataloader):
            batch_size = imgs.size(0)

            optimizer_D.zero_grad()

            real_imgs = imgs.to(device)
            real_labels = torch.ones(batch_size, 1).to(device)
            fake_labels = torch.zeros(batch_size, 1).to(device)

            # Обучение на реальных изображениях
            real_outputs = discriminator(real_imgs)
            d_loss_real = adversarial_loss(real_outputs, real_labels)
            d_loss_real.backward()

            # Обучение на сгенерированных изображениях
            z = Variable(torch.randn(batch_size, latent_dim)).to(device)
            fake_imgs = generator(z)
            fake_outputs = discriminator(fake_imgs.detach())
            d_loss_fake = adversarial_loss(fake_outputs, fake_labels)
            d_loss_fake.backward()

            d_loss = d_loss_real + d_loss_fake
            optimizer_D.step()

            optimizer_G.zero_grad()

            # Генерация изображений и подача их в дискриминатор
            fake_outputs = discriminator(fake_imgs)
            g_loss = adversarial_loss(fake_outputs, real_labels)
            g_loss.backward()
            optimizer_G.step()

            if i % 700 == 0:
                print(
                    f"Epoch [{epoch}/{epochs}], Batch Step [{i}/{len(dataloader)}], "
                    f"D Loss: {d_loss.item():.4f}, G Loss: {g_loss.item():.4f}"
                )

                # Запись в TensorBoard
                writer.add_scalar("Generator Loss", g_loss.item(), epoch * len(dataloader) + i)
                writer.add_scalar("Discriminator Loss", d_loss.item(), epoch * len(dataloader) + i)

        # Сохранение сгенерированных изображений на каждой эпохе
        save_image(fake_imgs.data[:25], f"images/generated_{epoch}.png", nrow=5, normalize=True)

    writer.close()

train_gan(generator, discriminator, optimizer_G, optimizer_D, adversarial_loss, dataloader, device, epochs=5)


Epoch [0/5], Batch Step [0/7446], D Loss: 0.5946, G Loss: 12.4421
Epoch [0/5], Batch Step [700/7446], D Loss: 0.0524, G Loss: 4.4863
Epoch [0/5], Batch Step [1400/7446], D Loss: 3.8264, G Loss: 16.7395
Epoch [0/5], Batch Step [2100/7446], D Loss: 0.0704, G Loss: 4.5511
Epoch [0/5], Batch Step [2800/7446], D Loss: 0.0832, G Loss: 6.1850
Epoch [0/5], Batch Step [3500/7446], D Loss: 0.1049, G Loss: 4.2233
Epoch [0/5], Batch Step [4200/7446], D Loss: 0.3697, G Loss: 11.1760
Epoch [0/5], Batch Step [4900/7446], D Loss: 0.0527, G Loss: 4.6337
Epoch [0/5], Batch Step [5600/7446], D Loss: 0.0705, G Loss: 4.1410
Epoch [0/5], Batch Step [6300/7446], D Loss: 0.0377, G Loss: 4.6106
Epoch [0/5], Batch Step [7000/7446], D Loss: 0.0295, G Loss: 5.4440
Epoch [1/5], Batch Step [0/7446], D Loss: 0.1228, G Loss: 8.1262
Epoch [1/5], Batch Step [700/7446], D Loss: 0.0237, G Loss: 5.6537
Epoch [1/5], Batch Step [1400/7446], D Loss: 0.0263, G Loss: 5.3088
Epoch [1/5], Batch Step [2100/7446], D Loss: 0.0112, 

In [10]:
# Путь к файлу модели
checkpoint_path = '5_model.pth'

# Создание словаря для сохранения состояний
checkpoint = {
    'generator_state_dict': generator.state_dict(),
    'discriminator_state_dict': discriminator.state_dict(),
    'optimizer_G_state_dict': optimizer_G.state_dict(),
    'optimizer_D_state_dict': optimizer_D.state_dict()
}

# Сохранение словаря в файл
torch.save(checkpoint, checkpoint_path)
print("Модели и оптимизаторы сохранены в", checkpoint_path)


Модели и оптимизаторы сохранены в 5_model.pth


In [9]:
# Загрузка словаря из файла
checkpoint = torch.load(checkpoint_path)

# Установка сохраненных параметров в модели и оптимизаторы
generator.load_state_dict(checkpoint['generator_state_dict'])
discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
optimizer_G.load_state_dict(checkpoint['optimizer_G_state_dict'])
optimizer_D.load_state_dict(checkpoint['optimizer_D_state_dict'])

# Продолжение обучения
train_gan(generator, discriminator, optimizer_G, optimizer_D, adversarial_loss, dataloader, device, epochs=5)


Epoch [0/5], Batch Step [0/7446], D Loss: 0.0020, G Loss: 8.1325
Epoch [0/5], Batch Step [700/7446], D Loss: 0.0064, G Loss: 7.3075
Epoch [0/5], Batch Step [1400/7446], D Loss: 0.0032, G Loss: 6.3493
Epoch [0/5], Batch Step [2100/7446], D Loss: 0.0057, G Loss: 6.8621
Epoch [0/5], Batch Step [2800/7446], D Loss: 0.0056, G Loss: 7.2149
Epoch [0/5], Batch Step [3500/7446], D Loss: 0.0327, G Loss: 4.6347
Epoch [0/5], Batch Step [4200/7446], D Loss: 0.0166, G Loss: 6.1404
Epoch [0/5], Batch Step [4900/7446], D Loss: 0.0042, G Loss: 7.0993
Epoch [0/5], Batch Step [5600/7446], D Loss: 0.0027, G Loss: 7.9843
Epoch [0/5], Batch Step [6300/7446], D Loss: 0.0290, G Loss: 5.7527
Epoch [0/5], Batch Step [7000/7446], D Loss: 0.0056, G Loss: 6.9116
Epoch [1/5], Batch Step [0/7446], D Loss: 0.1189, G Loss: 11.5759
Epoch [1/5], Batch Step [700/7446], D Loss: 0.0026, G Loss: 6.9711
Epoch [1/5], Batch Step [1400/7446], D Loss: 0.0009, G Loss: 7.8517
Epoch [1/5], Batch Step [2100/7446], D Loss: 0.0017, G 