### WGAN

* Модифицируйте код ячеек ниже и реализуйте [Wasserstein GAN](https://arxiv.org/abs/1701.07875) с клиппингом весов. (10 баллов)

* Замените клиппинг весов на [штраф градентов](https://arxiv.org/pdf/1704.00028v3.pdf). (10 баллов)

* Добавьте лейблы в WGAN, тем самым решая задачу [условной генерации](https://arxiv.org/pdf/1411.1784.pdf). (30 баллов)

Добавьте в этот файл анализ полученных результатов с различными графиками обучения и визуализацию генерации. Сравните как работает клиппинг весов и штраф градиентов и попробуйте пронаблюдать какие недостатки имеет модель GAN.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

import torchvision
import matplotlib.pyplot as plt
import numpy as np

from torch.autograd import Variable

### Простой конфиг (для хранения параметров, можете использовать и модифицировать)

In [2]:
class Config:
    pass

config = Config()
config.mnist_path = None
config.batch_size = 16
config.num_workers = 3
config.num_epochs = 10
config.noise_size = 50
config.print_freq = 500

### Создаем dataloader

In [3]:
train = torchvision.datasets.FashionMNIST("fashion_mnist", train=True, transform=torchvision.transforms.ToTensor(), download=True)

In [4]:
dataloader = DataLoader(train, batch_size=16, shuffle=True)
len(dataloader)

3750

In [5]:
image, label = next(iter(dataloader))
image.size()

torch.Size([16, 1, 28, 28])

### Создаем модель GAN

In [6]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(config.noise_size, 200),
            nn.ReLU(inplace=True),
            nn.Linear(200, 28*28),
            nn.Sigmoid())

    def forward(self, x):
        return self.model(x)

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(28*28, 200),
            nn.ReLU(inplace=True),
            nn.Linear(200, 50),
            nn.ReLU(inplace=True),
            nn.Linear(50, 1),
            nn.Sigmoid())
    def forward(self, x):
        return self.model(x)

In [7]:
generator = Generator()
discriminator = Discriminator()

### Оптимизатор и функция потерь

In [8]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
generator = generator.to(device)
discriminator = discriminator.to(device)
optim_G = optim.Adam(params=generator.parameters(), lr=0.0001)
optim_D = optim.Adam(params=discriminator.parameters(), lr=0.0001)

criterion = nn.BCELoss()

### Для оптимизации процесса обучения можно заранее определить переменные и заполнять их значения новыми данными

In [9]:
noise = torch.randn(config.batch_size, config.noise_size, device=device)
# fixed_noise = Variable(torch.FloatTensor(config.batch_size, config.noise_size, device=device).normal_(0, 1))
label = torch.empty(config.batch_size, device=device)
real_label = 1
fake_label = 0

RuntimeError: legacy constructor expects device type: cpu but device type: cuda was passed

### GAN обучение

In [None]:
ERRD_x = np.zeros(config.num_epochs)
ERRD_z = np.zeros(config.num_epochs)
ERRG = np.zeros(config.num_epochs)
N = len(dataloader)

for epoch in range(config.num_epochs):
    for iteration, (images, cat) in enumerate(dataloader):
        #######
        # Discriminator stage: maximize log(D(x)) + log(1 - D(G(z)))
        #######
        discriminator.zero_grad()

        # real
        label.data.fill_(real_label)
        input_data = images.view(images.shape[0], -1).to(device)
        output = discriminator(input_data).view(-1)
        errD_x = criterion(output, label)
        ERRD_x[epoch] += errD_x.item()
        errD_x.backward()

        # fake
        noise.data.normal_(0, 1)
        fake = generator(noise)
        label.data.fill_(fake_label)
        output = discriminator(fake.detach()).view(-1)
        errD_z = criterion(output, label)
        ERRD_z[epoch] += errD_z.item()
        errD_z.backward()

        optim_D.step()

        #######
        # Generator stage: maximize log(D(G(x))
        #######
        generator.zero_grad()
        label.data.fill_(real_label)
        output = discriminator(fake).view(-1)
        errG = criterion(output, label)
        ERRG[epoch] += errG.item()
        errG.backward()

        optim_G.step()

        if (iteration+1) % config.print_freq == 0:
            print('Epoch:{} Iter: {} errD_x: {:.2f} errD_z: {:.2f} errG: {:.2f}'.format(epoch+1,
                                                                                            iteration+1,
                                                                                            errD_x.item(),
                                                                                            errD_z.item(),
                                                                                            errG.item()))

In [None]:
noise.data.normal_(0, 1)
fake = generator(noise)

plt.figure(figsize=(6, 7))
for i in range(16):
    plt.subplot(4, 4, i + 1)
    plt.imshow(fake[i].detach().numpy().reshape(28, 28), cmap=plt.cm.Greys_r)
    plt.axis('off')

# Модифицируйте код ячеек ниже и реализуйте Wasserstein GAN с клиппингом весов. (10 баллов)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.autograd import Variable
import torchvision
import matplotlib.pyplot as plt
import numpy as np

# Определение устройства (GPU или CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Конфигурация
class Config:
    batch_size = 16
    noise_size = 50
    num_epochs = 10
    clip_value = 0.01  # Предел для клиппинга весов дискриминатора
    learning_rate = 0.0002

config = Config()

# Загрузка данных
train_dataset = torchvision.datasets.FashionMNIST(
    root='./fashion_mnist', train=True, download=True, transform=torchvision.transforms.ToTensor())
train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)

# Определение Генератора
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(config.noise_size, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 28*28),
            nn.Tanh())

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), 1, 28, 28)
        return img

# Определение Дискриминатора
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1))

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)
        return validity

# Инициализация моделей
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# Оптимизаторы
optimizer_G = optim.RMSprop(generator.parameters(), lr=config.learning_rate)
optimizer_D = optim.RMSprop(discriminator.parameters(), lr=config.learning_rate)

# Тренировка
Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
n_critic = 5

for epoch in range(config.num_epochs):
    for i, (imgs, _) in enumerate(train_loader):

        # Настоящие и поддельные изображения
        real_imgs = Variable(imgs.type(Tensor)).to(device)
        z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], config.noise_size)))).to(device)

        # Тренировка дискриминатора
        optimizer_D.zero_grad()

        # Генерация поддельных изображений
        fake_imgs = generator(z)

        d_loss_real = -torch.mean(discriminator(real_imgs))
        d_loss_fake = torch.mean(discriminator(fake_imgs.detach()))  # Detach fake_imgs для предотвращения повторного использования в генераторе
        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        optimizer_D.step()

        # Клиппинг весов дискриминатора
        for p in discriminator.parameters():
            p.data.clamp_(-config.clip_value, config.clip_value)

        # Тренировка генератора каждые n_critic итераций
        if i % n_critic == 0:
            optimizer_G.zero_grad()
            # Генерируем новые поддельные изображения для генератора
            fake_imgs_gen = generator(z)
            g_loss = -torch.mean(discriminator(fake_imgs_gen))
            g_loss.backward()
            optimizer_G.step()

        # Вывод прогресса
        if i % 100 == 0:
            print(f"[Epoch {epoch}/{config.num_epochs}] [Batch {i}/{len(train_loader)}] [D loss: {d_loss.item()}] [G loss: {g_loss.item()}]")

# Визуализация сгенерированных изображений
z = Variable(Tensor(np.random.normal(0, 1, (16, config.noise_size)))).to(device)
gen_imgs = generator(z)
gen_imgs = gen_imgs.view(gen_imgs.size(0), 28, 28)

plt.figure(figsize=(8, 8))
for i in range(16):
    plt.subplot(4, 4, i+1)
    plt.imshow(gen_imgs[i].detach().cpu().numpy(), cmap='gray')  # Перенос на CPU для визуализации
    plt.axis('off')
plt.show()

  z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], config.noise_size)))).to(device)


[Epoch 0/10] [Batch 0/3750] [D loss: -0.04200122505426407] [G loss: 0.01035439595580101]
[Epoch 0/10] [Batch 100/3750] [D loss: -0.1731003224849701] [G loss: 0.08743385225534439]
[Epoch 0/10] [Batch 200/3750] [D loss: -1.0171973705291748] [G loss: 0.8051367998123169]
[Epoch 0/10] [Batch 300/3750] [D loss: -0.580783486366272] [G loss: -0.8433276414871216]
[Epoch 0/10] [Batch 400/3750] [D loss: -0.011880559846758842] [G loss: 0.10061788558959961]
[Epoch 0/10] [Batch 500/3750] [D loss: -0.03592371940612793] [G loss: -0.7771020531654358]
[Epoch 0/10] [Batch 600/3750] [D loss: -0.2191803753376007] [G loss: -0.19149842858314514]
[Epoch 0/10] [Batch 700/3750] [D loss: -0.04937458783388138] [G loss: 0.09418998658657074]
[Epoch 0/10] [Batch 800/3750] [D loss: -0.39514318108558655] [G loss: -0.24274247884750366]
[Epoch 0/10] [Batch 900/3750] [D loss: -0.4713520407676697] [G loss: 1.1858477592468262]
[Epoch 0/10] [Batch 1000/3750] [D loss: -0.28361696004867554] [G loss: -0.028889410197734833]
[Ep

[Epoch 2/10] [Batch 1500/3750] [D loss: 0.06079000234603882] [G loss: 0.5674723386764526]
[Epoch 2/10] [Batch 1600/3750] [D loss: -0.059599995613098145] [G loss: 0.5073263645172119]
[Epoch 2/10] [Batch 1700/3750] [D loss: -0.028300940990447998] [G loss: -0.1871200054883957]
[Epoch 2/10] [Batch 1800/3750] [D loss: -0.053586721420288086] [G loss: -2.125556230545044]
[Epoch 2/10] [Batch 1900/3750] [D loss: -0.09528446197509766] [G loss: -0.35791313648223877]
[Epoch 2/10] [Batch 2000/3750] [D loss: -0.010291635990142822] [G loss: 0.6099671125411987]
[Epoch 2/10] [Batch 2100/3750] [D loss: -0.23707059025764465] [G loss: 0.7361066341400146]
[Epoch 2/10] [Batch 2200/3750] [D loss: -0.14030885696411133] [G loss: -0.9493051767349243]
[Epoch 2/10] [Batch 2300/3750] [D loss: -0.12369069457054138] [G loss: -0.3546847403049469]
[Epoch 2/10] [Batch 2400/3750] [D loss: -0.2580986022949219] [G loss: -0.43570512533187866]
[Epoch 2/10] [Batch 2500/3750] [D loss: -0.20661848783493042] [G loss: -0.7358745

[Epoch 4/10] [Batch 2900/3750] [D loss: 0.017080307006835938] [G loss: 0.015188068151473999]
[Epoch 4/10] [Batch 3000/3750] [D loss: -0.07230830192565918] [G loss: 1.3973942995071411]
[Epoch 4/10] [Batch 3100/3750] [D loss: -0.5918034315109253] [G loss: 2.2721548080444336]
[Epoch 4/10] [Batch 3200/3750] [D loss: -0.026296377182006836] [G loss: -0.3170163929462433]
[Epoch 4/10] [Batch 3300/3750] [D loss: -0.20544064044952393] [G loss: 0.15789783000946045]
[Epoch 4/10] [Batch 3400/3750] [D loss: -0.004519820213317871] [G loss: 0.296251505613327]
[Epoch 4/10] [Batch 3500/3750] [D loss: -0.024879667907953262] [G loss: -0.007042234297841787]
[Epoch 4/10] [Batch 3600/3750] [D loss: 0.1984536349773407] [G loss: -0.25843510031700134]
[Epoch 4/10] [Batch 3700/3750] [D loss: -0.00458449125289917] [G loss: -0.43060046434402466]
[Epoch 5/10] [Batch 0/3750] [D loss: -0.07321912050247192] [G loss: -0.32612594962120056]
[Epoch 5/10] [Batch 100/3750] [D loss: -0.02584215998649597] [G loss: -0.20687428

[Epoch 7/10] [Batch 500/3750] [D loss: -0.05206114053726196] [G loss: 0.24738964438438416]
[Epoch 7/10] [Batch 600/3750] [D loss: 0.012649953365325928] [G loss: -0.4579246938228607]
[Epoch 7/10] [Batch 700/3750] [D loss: 0.2541317939758301] [G loss: 1.3889122009277344]
[Epoch 7/10] [Batch 800/3750] [D loss: 0.17884927988052368] [G loss: 0.6247106790542603]
[Epoch 7/10] [Batch 900/3750] [D loss: -0.037195056676864624] [G loss: 0.16082951426506042]
[Epoch 7/10] [Batch 1000/3750] [D loss: -0.742232084274292] [G loss: 2.438591480255127]
[Epoch 7/10] [Batch 1100/3750] [D loss: 0.014707803726196289] [G loss: 0.2458633929491043]
[Epoch 7/10] [Batch 1200/3750] [D loss: -0.6909189224243164] [G loss: -1.4432116746902466]
[Epoch 7/10] [Batch 1300/3750] [D loss: 0.006399750709533691] [G loss: 0.020867131650447845]
[Epoch 7/10] [Batch 1400/3750] [D loss: 0.078066386282444] [G loss: -0.0028348376508802176]
[Epoch 7/10] [Batch 1500/3750] [D loss: -0.013565748929977417] [G loss: 0.12783363461494446]
[

[Epoch 9/10] [Batch 1900/3750] [D loss: -0.1434633731842041] [G loss: 0.379936158657074]
[Epoch 9/10] [Batch 2000/3750] [D loss: 0.022326409816741943] [G loss: -0.5591810941696167]
[Epoch 9/10] [Batch 2100/3750] [D loss: -0.10522960126399994] [G loss: -0.05189116671681404]
[Epoch 9/10] [Batch 2200/3750] [D loss: -0.2744654417037964] [G loss: -0.11260104179382324]
[Epoch 9/10] [Batch 2300/3750] [D loss: -0.04838836193084717] [G loss: -0.8663029074668884]
[Epoch 9/10] [Batch 2400/3750] [D loss: -0.05054086446762085] [G loss: -0.24177628755569458]
[Epoch 9/10] [Batch 2500/3750] [D loss: -0.1795225739479065] [G loss: -0.8157445192337036]
[Epoch 9/10] [Batch 2600/3750] [D loss: -0.029933661222457886] [G loss: 0.2857365906238556]
[Epoch 9/10] [Batch 2700/3750] [D loss: 0.07682500779628754] [G loss: 0.0703272745013237]
[Epoch 9/10] [Batch 2800/3750] [D loss: 0.06161035597324371] [G loss: 0.1399630457162857]
[Epoch 9/10] [Batch 2900/3750] [D loss: -0.028930924832820892] [G loss: 0.013900459744

# Замените клиппинг весов на штраф градентов. (10 баллов)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.autograd import Variable
import torchvision
import matplotlib.pyplot as plt
import numpy as np

# Определение устройства (GPU или CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Конфигурация
class Config:
    batch_size = 16
    noise_size = 100
    num_epochs = 10
    learning_rate = 0.0002
    lambda_gp = 10

config = Config()

# Загрузка данных
train_dataset = torchvision.datasets.FashionMNIST(root='./fashion_mnist', train=True, download=True, transform=torchvision.transforms.ToTensor())
train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)

# Определение Генератора
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(config.noise_size, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 28*28),
            nn.Tanh())

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), 1, 28, 28)
        return img

# Определение Дискриминатора
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1))

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)
        return validity

# Инициализация моделей
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# Оптимизаторы
optimizer_G = optim.Adam(generator.parameters(), lr=config.learning_rate, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=config.learning_rate, betas=(0.5, 0.999))

# Функция для вычисления штрафа градиента
def compute_gradient_penalty(D, real_samples, fake_samples):
    alpha = torch.rand(real_samples.size(0), 1, 1, 1, device=real_samples.device)
    alpha = alpha.expand_as(real_samples)
    interpolated = Variable(alpha * real_samples + (1 - alpha) * fake_samples, requires_grad=True)
    d_interpolated = D(interpolated)
    fake = Variable(torch.ones(d_interpolated.size(), device=real_samples.device), requires_grad=False)

    gradients = torch.autograd.grad(
        outputs=d_interpolated,
        inputs=interpolated,
        grad_outputs=fake,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]

    gradients = gradients.view(gradients.size(0), -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty

# Тренировка
Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
n_critic = 5

for epoch in range(config.num_epochs):
    for i, (imgs, _) in enumerate(train_loader):

        # Конфигурация входных данных
        real_imgs = Variable(imgs.type(Tensor)).to(device)
        z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], config.noise_size)))).to(device)

        # Генерация поддельных изображений
        fake_imgs = generator(z)

        # ---------------------
        #  Тренировка Дискриминатора
        # ---------------------

        optimizer_D.zero_grad()

        # Реальные изображения
        real_validity = discriminator(real_imgs)
        # Поддельные изображения
        fake_validity = discriminator(fake_imgs)
        # Штраф градиента
        gradient_penalty = compute_gradient_penalty(discriminator, real_imgs.data, fake_imgs.data)
        # Функция потерь дискриминатора
        d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + config.lambda_gp * gradient_penalty

        d_loss.backward()
        optimizer_D.step()

        # -----------------
        #  Тренировка Генератора
        # -----------------

        if i % n_critic == 0:
            optimizer_G.zero_grad()

            # Генерируем изображение
            generated_imgs = generator(z)
            # Функция потерь генератора
            g_loss = -torch.mean(discriminator(generated_imgs))

            g_loss.backward()
            optimizer_G.step()

        # Вывод прогресса
        if i % 100 == 0:
            print(f"[Epoch {epoch}/{config.num_epochs}] [Batch {i}/{len(train_loader)}] [D loss: {d_loss.item()}] [G loss: {g_loss.item()}]")

# Визуализация сгенерированных изображений
z = Variable(Tensor(np.random.normal(0, 1, (16, config.noise_size)))).to(device)
gen_imgs = generator(z)
gen_imgs = gen_imgs.view(gen_imgs.size(0), 28, 28)

plt.figure(figsize=(8, 8))
for i in range(16):
    plt.subplot(4, 4, i+1)
    plt.imshow(gen_imgs[i].detach().cpu().numpy(), cmap='gray')
    plt.axis('off')
plt.show()
plt.show()

# Добавьте лейблы в WGAN, тем самым решая задачу условной генерации. (30 баллов)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import FashionMNIST
from torchvision import transforms
import pytorch_lightning as pl
import numpy as np
import matplotlib.pyplot as plt

# Конфигурация
class Config:
    batch_size = 16
    noise_size = 50
    num_epochs = 10
    learning_rate = 0.0002
    num_classes = 10  # Количество классов в FashionMNIST
    embed_size = 50   # Размерность вектора встраивания для лейблов

config = Config()

# Датасет и DataLoader с использованием PyTorch Lightning
class FashionMNISTDataModule(pl.LightningDataModule):
    def __init__(self, data_dir='./fashion_mnist', batch_size=config.batch_size):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.transform = transforms.ToTensor()

    def setup(self, stage=None):
        self.fashion_mnist_train = FashionMNIST(self.data_dir, train=True, download=True, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.fashion_mnist_train, batch_size=self.batch_size, shuffle=True)

# Определение Генератора
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.label_emb = nn.Embedding(config.num_classes, config.embed_size)
        self.model = nn.Sequential(
            nn.Linear(config.noise_size + config.embed_size, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 28*28),
            nn.Tanh())

    def forward(self, z, labels):
        c = self.label_emb(labels)
        x = torch.cat([z, c], 1)
        img = self.model(x)
        img = img.view(img.size(0), 1, 28, 28)
        return img

# Определение Дискриминатора
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.label_emb = nn.Embedding(config.num_classes, config.embed_size)
        self.model = nn.Sequential(
            nn.Linear(28*28 + config.embed_size, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1))

    def forward(self, img, labels):
        img_flat = img.view(img.size(0), -1)
        c = self.label_emb(labels)
        x = torch.cat([img_flat, c], 1)
        validity = self.model(x)
        return validity

class GAN(pl.LightningModule):
    def __init__(self):
        super(GAN, self).__init__()
        self.generator = Generator()
        self.discriminator = Discriminator()
        self.automatic_optimization = False  # Включаем ручное управление оптимизациями

    def forward(self, z, labels):
        return self.generator(z, labels)

    def training_step(self, batch, batch_idx):
        imgs, labels = batch
        z = torch.randn(imgs.shape[0], config.noise_size).type_as(imgs)

        # Получаем оптимизаторы
        opt_g, opt_d = self.optimizers()

        # Обучение дискриминатора
        real_loss = torch.mean(self.discriminator(imgs, labels))
        fake_imgs = self(z, labels).detach()
        fake_loss = torch.mean(self.discriminator(fake_imgs, labels))
        d_loss = real_loss - fake_loss
        opt_d.zero_grad()
        self.manual_backward(d_loss)
        opt_d.step()

        # Обучение генератора
        fake_imgs = self(z, labels)
        g_loss = -torch.mean(self.discriminator(fake_imgs, labels))
        opt_g.zero_grad()
        self.manual_backward(g_loss)
        opt_g.step()

        self.log('d_loss', d_loss, prog_bar=True)
        self.log('g_loss', g_loss, prog_bar=True)

    def configure_optimizers(self):
        opt_g = optim.Adam(self.generator.parameters(), lr=config.learning_rate)
        opt_d = optim.Adam(self.discriminator.parameters(), lr=config.learning_rate)
        return opt_g, opt_d

# Инициализация и запуск тренировки
dm = FashionMNISTDataModule()
model = GAN()
trainer = pl.Trainer(max_epochs=config.num_epochs)
trainer.fit(model, dm)

# Визуализация сгенерированных изображений
z = torch.randn(16, config.noise_size)
labels = torch.randint(0, config.num_classes, (16,))
gen_imgs = model.generator(z, labels)
gen_imgs = gen_imgs.view(gen_imgs.size(0), 28, 28)

plt.figure(figsize=(8, 8))
for i in range(16):
    plt.subplot(4, 4, i+1)
    plt.imshow(gen_imgs[i].detach().cpu().numpy(), cmap='gray')
    plt.axis('off')
plt.show()

# Попытка объедения 2 и 3 заданий вместе!

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import FashionMNIST
from torchvision import transforms
import pytorch_lightning as pl
import numpy as np
import matplotlib.pyplot as plt

# Конфигурация
class Config:
    batch_size = 16
    noise_size = 100
    num_epochs = 5
    learning_rate = 0.0002
    lambda_gp = 10

config = Config()

# Датасет и DataLoader с использованием PyTorch Lightning
class FashionMNISTDataModule(pl.LightningDataModule):
    def __init__(self, data_dir='./fashion_mnist', batch_size=config.batch_size):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.transform = transforms.ToTensor()

    def setup(self, stage=None):
        self.fashion_mnist_train = FashionMNIST(self.data_dir, train=True, download=True, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.fashion_mnist_train, batch_size=self.batch_size, shuffle=True)

# Определение Генератора
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(config.noise_size, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 28*28),
            nn.Tanh())

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), 1, 28, 28)
        return img
    
# Определение Дискриминатора
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1))

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)
        return validity
    
# LightningModule для GAN
class GAN(pl.LightningModule):
    def __init__(self):
        super(GAN, self).__init__()
        self.generator = Generator()
        self.discriminator = Discriminator()
        self.automatic_optimization = False

    def compute_gradient_penalty(self, real_samples, fake_samples):
        alpha = torch.rand(real_samples.size(0), 1, 1, 1, device=real_samples.device)
        alpha = alpha.expand_as(real_samples)
        interpolated = Variable(alpha * real_samples + (1 - alpha) * fake_samples, requires_grad=True)
        d_interpolated = self.discriminator(interpolated)
        fake = Variable(torch.ones(d_interpolated.size(), device=real_samples.device), requires_grad=False)

        gradients = torch.autograd.grad(
            outputs=d_interpolated,
            inputs=interpolated,
            grad_outputs=fake,
            create_graph=True,
            retain_graph=True,
            only_inputs=True,
        )[0]

        gradients = gradients.view(gradients.size(0), -1)
        gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
        return gradient_penalty

    def training_step(self, batch, batch_idx):
        imgs = batch[0]
        z = torch.randn(imgs.shape[0], config.noise_size).type_as(imgs)

        # Получаем оптимизаторы
        opt_g, opt_d = self.optimizers()

        # Обучение дискриминатора
        real_imgs = imgs
        fake_imgs = self.generator(z)
        
        real_validity = self.discriminator(real_imgs)
        fake_validity = self.discriminator(fake_imgs.detach())
        gradient_penalty = self.compute_gradient_penalty(real_imgs, fake_imgs.detach())
        d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + config.lambda_gp * gradient_penalty

        opt_d.zero_grad()
        self.manual_backward(d_loss)
        opt_d.step()

        # Обучение генератора
        if batch_idx % 5 == 0:
            generated_imgs = self.generator(z)
            g_loss = -torch.mean(self.discriminator(generated_imgs))

            opt_g.zero_grad()
            self.manual_backward(g_loss)
            opt_g.step()

            self.log('g_loss', g_loss, prog_bar=True)

        self.log('d_loss', d_loss, prog_bar=True)

    def configure_optimizers(self):
        opt_g = optim.Adam(self.generator.parameters(), lr=config.learning_rate, betas=(0.5, 0.999))
        opt_d = optim.Adam(self.discriminator.parameters(), lr=config.learning_rate, betas=(0.5, 0.999))
        return opt_g, opt_d
    
# Инициализация и запуск тренировки
dm = FashionMNISTDataModule()
model = GAN()
trainer = pl.Trainer(max_epochs=config.num_epochs)
trainer.fit(model, dm)

# Визуализация сгенерированных изображений
z = torch.randn(16, config.noise_size)
gen_imgs = model.generator(z)
gen_imgs = gen_imgs.view(gen_imgs.size(0), 28, 28)

plt.figure(figsize=(8, 8))
for i in range(16):
    plt.subplot(4, 4, i+1)
    plt.imshow(gen_imgs[i].detach().cpu().numpy(), cmap='gray')
    plt.axis('off')
plt.show()