In [1]:
import os

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import save_image

Для расчета на GPU

In [2]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

device(type='cuda')

Загрузка датасета и формирование даталодера

In [3]:
# ToTensor scales pixels to [0, 1]; Normalize with mean 0.5 / std 0.5 then maps
# them to [-1, 1], the same range as the generator's Tanh output.
transform = transforms.Compose([
    transforms.ToTensor(),
    # One-element tuples (one value per channel); `(0.5)` is only a float.
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

# download=True only fetches the files when they are missing under `root`,
# so the cell also works on a machine where ./mnist_data/ does not exist yet.
train_dataset = datasets.MNIST(root='./mnist_data/', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./mnist_data/', train=False, transform=transform, download=True)

# Training batches are reshuffled every epoch; test batches keep dataset order.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=100, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=100, shuffle=False)

Генератор и дискриминатор (со сверточными слоями в структуре)

In [4]:
class Generator(nn.Module):
    """Transposed-convolution generator producing single-channel images.

    For an input of shape ``(N, 100, 1, 1)`` the spatial size grows
    1 -> 7 -> 14 -> 28 -> 28, so the output has shape ``(N, 1, 28, 28)``.
    The final ``Tanh`` keeps every output value inside ``[-1, 1]``.
    """

    def __init__(self):
        super().__init__()

        # (in_channels, out_channels, kernel_size, stride, padding) of each
        # "ConvTranspose2d -> BatchNorm2d -> ReLU" stage, in forward order.
        upsampling_stages = (
            (100, 64, 7, 1, 0),
            (64, 32, 4, 2, 1),
            (32, 16, 4, 2, 1),
        )

        layers = []
        for c_in, c_out, kernel, stride, pad in upsampling_stages:
            layers += [
                # The bias is dropped because BatchNorm follows immediately.
                nn.ConvTranspose2d(c_in, c_out, kernel_size=kernel, stride=stride, padding=pad, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            ]

        # Output head: project to one channel without changing the spatial size.
        layers += [
            nn.ConvTranspose2d(16, 1, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),
        ]

        self.seq = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the generated batch."""
        return self.seq(x)


class Discriminator(nn.Module):
    """Strided-convolution discriminator for single-channel images.

    For an input of shape ``(N, 1, 28, 28)`` the spatial size shrinks
    28 -> 14 -> 7 -> 4 -> 1, so the output has shape ``(N, 1, 1, 1)``.
    The final ``Sigmoid`` maps every score into ``[0, 1]``.
    """

    def __init__(self):
        super().__init__()

        # Input stage: plain convolution (with bias) and no normalisation.
        modules = [
            nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
        ]

        # (in_channels, out_channels, kernel_size) of each
        # "Conv2d -> BatchNorm2d -> LeakyReLU" stage; stride 2 and padding 1 throughout.
        for c_in, c_out, kernel in ((16, 32, 4), (32, 64, 3)):
            modules.extend((
                nn.Conv2d(c_in, c_out, kernel_size=kernel, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2),
            ))

        # Output head: collapse the 4x4 feature map into a single score.
        modules.extend((
            nn.Conv2d(64, 1, kernel_size=4, stride=2, padding=0),
            nn.Sigmoid(),
        ))

        self.seq = nn.Sequential(*modules)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the per-sample scores."""
        return self.seq(x)

In [7]:
# Number of latent channels sampled for the generator; must match the
# in_channels (100) of the generator's first transposed convolution.
z_dim = 100
# Flattened image size (height * width). `data` is the current name of the
# image tensor; the old `train_data` alias is deprecated and emits a warning.
mnist_dim = train_dataset.data.size(1) * train_dataset.data.size(2)
G = Generator().to(device)
D = Discriminator().to(device)
# Binary cross-entropy on the discriminator's sigmoid output.
criterion = nn.BCELoss()
# Separate Adam optimizers so each network is updated independently.
G_optimizer = optim.Adam(G.parameters(), lr = 0.0003, betas=(0.5, 0.999))
D_optimizer = optim.Adam(D.parameters(), lr = 0.0003, betas=(0.5, 0.999))

In [8]:
num_epochs = 200
sample_dir = './conv_gan_samples'
# save_image does not create missing directories; make sure the target exists
# up front instead of failing after the first full epoch.
os.makedirs(sample_dir, exist_ok=True)

# Epochs are numbered 0..num_epochs inclusive (as in the progress line below).
for epoch in range(num_epochs + 1):
    D_losses, G_losses = [], []
    for x, _ in train_loader:
        x_real = x.to(device)
        # Take the batch size from the data so that a final partial batch or a
        # different DataLoader batch_size does not break the label shapes.
        batch_size = x_real.size(0)
        y_real = torch.ones(batch_size, 1, device=device)
        y_fake = torch.zeros(batch_size, 1, device=device)

        # _______Discriminator____________
        D.zero_grad()

        D_output = D(x_real)
        D_real_loss = criterion(D_output.view(batch_size, 1), y_real)

        z = torch.randn(batch_size, z_dim, 1, 1, device=device)
        x_fake = G(z)

        # detach(): the discriminator update needs no gradients w.r.t. G, so
        # skip back-propagating through the generator here.
        D_output = D(x_fake.detach())
        D_fake_loss = criterion(D_output.view(batch_size, 1), y_fake)

        D_loss = D_real_loss + D_fake_loss
        D_loss.backward()
        D_optimizer.step()
        D_losses.append(D_loss.item())

        # _______Generator____________
        G.zero_grad()

        # Reuse the fake batch (its graph through G is still intact) and score
        # it with the freshly updated discriminator. The generator is trained
        # to make D output "real", hence the all-ones target.
        D_output = D(x_fake)
        G_loss = criterion(D_output.view(batch_size, 1), y_real)
        G_loss.backward()
        G_optimizer.step()
        G_losses.append(G_loss.item())

    print('[%d/%d]: loss_d: %.3f, loss_g: %.3f' % (
            (epoch), 200, torch.mean(torch.FloatTensor(D_losses)), torch.mean(torch.FloatTensor(G_losses))))

    # Every second epoch, save a grid of samples drawn from the last noise batch.
    if epoch % 2 == 0:
        with torch.no_grad():
            generated = G(z)
            # The Tanh output lies in [-1, 1] while save_image clamps to [0, 1];
            # map back to [0, 1] (inverse of Normalize(0.5, 0.5)) so negative
            # values are not all flattened to black.
            samples = generated.view(generated.size(0), 1, 28, 28) * 0.5 + 0.5
            save_image(samples, os.path.join(sample_dir, '{}.png'.format(epoch)))

[0/200]: loss_d: 0.997, loss_g: 1.145
[1/200]: loss_d: 0.994, loss_g: 1.140
[2/200]: loss_d: 1.016, loss_g: 1.137
[3/200]: loss_d: 1.003, loss_g: 1.160
[4/200]: loss_d: 0.992, loss_g: 1.192
[5/200]: loss_d: 0.992, loss_g: 1.212
[6/200]: loss_d: 0.987, loss_g: 1.229
[7/200]: loss_d: 0.978, loss_g: 1.249
[8/200]: loss_d: 0.970, loss_g: 1.264
[9/200]: loss_d: 0.969, loss_g: 1.279
[10/200]: loss_d: 0.957, loss_g: 1.303
[11/200]: loss_d: 0.959, loss_g: 1.309
[12/200]: loss_d: 0.946, loss_g: 1.324
[13/200]: loss_d: 0.946, loss_g: 1.348
[14/200]: loss_d: 0.938, loss_g: 1.347
[15/200]: loss_d: 0.926, loss_g: 1.373
[16/200]: loss_d: 0.925, loss_g: 1.385
[17/200]: loss_d: 0.919, loss_g: 1.396
[18/200]: loss_d: 0.917, loss_g: 1.400
[19/200]: loss_d: 0.913, loss_g: 1.408
[20/200]: loss_d: 0.909, loss_g: 1.409
[21/200]: loss_d: 0.902, loss_g: 1.434
[22/200]: loss_d: 0.894, loss_g: 1.446
[23/200]: loss_d: 0.889, loss_g: 1.462
[24/200]: loss_d: 0.887, loss_g: 1.458
[25/200]: loss_d: 0.888, loss_g: 1.

KeyboardInterrupt: 