In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
import matplotlib.pyplot as plt

Для расчета на GPU

In [9]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

Загрузка датасета и формирование даталодера

In [5]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5), std=(0.5))])

train_dataset = datasets.MNIST(root='./mnist_data/', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./mnist_data/', train=False, transform=transform, download=False)

train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=100, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=100, shuffle=False)

Классы генератора и дискриминатора (с линейными слоями)

In [6]:
class Generator(nn.Module):
    def __init__(self, g_input_dim, g_output_dim):
        super(Generator, self).__init__()       
        self.fc1 = nn.Linear(g_input_dim, 256)
        self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features*2)
        self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features*2)
        self.fc4 = nn.Linear(self.fc3.out_features, g_output_dim)
    
    def forward(self, x): 
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.leaky_relu(self.fc3(x), 0.2)
        return torch.tanh(self.fc4(x))
    
class Discriminator(nn.Module):
    def __init__(self, d_input_dim):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(d_input_dim, 1024)
        self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features//2)
        self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features//2)
        self.fc4 = nn.Linear(self.fc3.out_features, 1)
    
    # forward method
    def forward(self, x):
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = F.dropout(x, 0.3)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.dropout(x, 0.3)
        x = F.leaky_relu(self.fc3(x), 0.2)
        x = F.dropout(x, 0.3)
        return torch.sigmoid(self.fc4(x))

In [10]:
z_dim = 100
mnist_dim = train_dataset.train_data.size(1) * train_dataset.train_data.size(2)
G = Generator(g_input_dim = z_dim, g_output_dim = mnist_dim).to(device)
D = Discriminator(mnist_dim).to(device)
criterion = nn.BCELoss()  
G_optimizer = optim.Adam(G.parameters(), lr = 0.0002)
D_optimizer = optim.Adam(D.parameters(), lr = 0.0002)

In [11]:
for epoch in range(0, 200+1):           
    D_losses, G_losses = [], []
    for batch_idx, (x, _) in enumerate(train_loader):
        # _______Discriminator____________
        D.zero_grad()
        
        x_real, y_real = x.view(-1, mnist_dim).to(device), torch.ones(100, 1).to(device)

        D_output = D(x_real)
        D_real_loss = criterion(D_output, y_real)

        z = torch.randn(100, z_dim).to(device)
        x_fake, y_fake = G(z), torch.zeros(100, 1).to(device)
        D_output = D(x_fake)
        D_fake_loss = criterion(D_output, y_fake)

        D_loss = D_real_loss + D_fake_loss
        D_loss.backward()
        D_optimizer.step()
        D_losses.append(D_loss.item())
        
        # _______Generator____________
        G.zero_grad()

        z = torch.randn(100, z_dim).to(device)
        y = torch.ones(100, 1).to(device)

        G_output = G(z)
        D_output = D(G_output)
        G_loss = criterion(D_output, y)
        G_loss.backward()
        G_optimizer.step() 
        
        G_losses.append(G_loss.item())

    print('[%d/%d]: loss_d: %.3f, loss_g: %.3f' % (
            (epoch), 200, torch.mean(torch.FloatTensor(D_losses)), torch.mean(torch.FloatTensor(G_losses))))

[0/200]: loss_d: 0.733, loss_g: 4.224
[1/200]: loss_d: 0.866, loss_g: 2.030
[2/200]: loss_d: 0.738, loss_g: 2.510
[3/200]: loss_d: 0.549, loss_g: 2.893
[4/200]: loss_d: 0.533, loss_g: 2.828
[5/200]: loss_d: 0.505, loss_g: 2.939
[6/200]: loss_d: 0.523, loss_g: 2.789
[7/200]: loss_d: 0.535, loss_g: 2.717
[8/200]: loss_d: 0.557, loss_g: 2.627
[9/200]: loss_d: 0.570, loss_g: 2.730
[10/200]: loss_d: 0.649, loss_g: 2.409
[11/200]: loss_d: 0.684, loss_g: 2.219
[12/200]: loss_d: 0.761, loss_g: 1.968
[13/200]: loss_d: 0.774, loss_g: 1.887
[14/200]: loss_d: 0.812, loss_g: 1.836
[15/200]: loss_d: 0.822, loss_g: 1.845
[16/200]: loss_d: 0.808, loss_g: 1.920
[17/200]: loss_d: 0.821, loss_g: 1.871
[18/200]: loss_d: 0.767, loss_g: 2.032
[19/200]: loss_d: 0.814, loss_g: 1.880
[20/200]: loss_d: 0.875, loss_g: 1.693
[21/200]: loss_d: 0.882, loss_g: 1.642
[22/200]: loss_d: 0.920, loss_g: 1.566
[23/200]: loss_d: 0.941, loss_g: 1.539
[24/200]: loss_d: 0.961, loss_g: 1.502
[25/200]: loss_d: 0.946, loss_g: 1.

In [None]:
with torch.no_grad():
            test_z = torch.randn(100, z_dim).to(device)
            generated = G(test_z)
            save_image(generated.view(generated.size(0), 1, 28, 28), './gan_samples/{}'.format(epoch) + '.png')