In [1]:
import os

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import save_image

Выбор устройства для вычислений: GPU (CUDA), если он доступен, иначе CPU

In [2]:
# Prefer the CUDA device when one is usable; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device

device(type='cuda')

Загрузка датасета MNIST и создание загрузчиков данных (DataLoader) для обучающей и тестовой выборок

In [3]:
# Preprocessing: ToTensor yields float tensors in [0, 1]; Normalize with mean 0.5 and
# std 0.5 maps them to [-1, 1], the same range as the generator's Tanh output.
transform = transforms.Compose([
    transforms.ToTensor(),
    # One-element tuples (one value per channel): `(0.5)` is just the float 0.5.
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

# download=True fetches the files only when they are missing from `root`, so this
# cell also works after "Restart & Run All" on a machine without a local copy.
train_dataset = datasets.MNIST(root='./mnist_data/', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./mnist_data/', train=False, transform=transform, download=True)

# Training batches are reshuffled every epoch; the test order stays fixed.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=100, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=100, shuffle=False)

Генератор и дискриминатор с добавлением вектора условий

In [61]:
class Generator(nn.Module):
    """Conditional DCGAN-style generator that produces 1x28x28 images.

    The class label is embedded into a vector, reshaped to (N, embed_dim, 1, 1) and
    concatenated channel-wise with the latent noise. A stack of transposed
    convolutions then upsamples 1x1 -> 7x7 -> 14x14 -> 28x28, and the final Tanh
    bounds the output to [-1, 1].

    Args:
        z_dim: number of channels of the latent noise input (default 100).
        num_classes: number of distinct class labels (default 10).
        embed_dim: size of the learned label embedding (default 10).
    """

    def __init__(self, z_dim=100, num_classes=10, embed_dim=10):
        super().__init__()
        self.embed = nn.Embedding(num_classes, embed_dim)
        self.seq = nn.Sequential(
            # Input channels = noise channels + label-embedding channels (110 by default).
            nn.ConvTranspose2d(z_dim + embed_dim, 64, kernel_size=7, stride=1, padding=0, bias=False),  # 1x1 -> 7x7
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1, bias=False),  # 7x7 -> 14x14
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1, bias=False),  # 14x14 -> 28x28
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 1, kernel_size=3, stride=1, padding=1),  # keeps 28x28
            nn.Tanh(),
        )

    def forward(self, x, labels):
        """Generate images from noise `x` conditioned on `labels`.

        Args:
            x: noise tensor of shape (N, z_dim, 1, 1).
            labels: integer class indices of shape (N,).

        Returns:
            Tensor of shape (N, 1, 28, 28) with values in [-1, 1].
        """
        embedding = self.embed(labels).unsqueeze(2).unsqueeze(3)  # (N, embed_dim, 1, 1)
        x = torch.cat([x, embedding], dim=1)  # (N, z_dim + embed_dim, 1, 1)
        return self.seq(x)


class Discriminator(nn.Module):
    """Conditional DCGAN-style discriminator for 28x28 images.

    The class label is embedded into a 28*28 vector, reshaped into one extra image
    plane and concatenated channel-wise with the input image. Four strided
    convolutions reduce 28x28 -> 14x14 -> 7x7 -> 4x4 -> 1x1, and the final Sigmoid
    turns the score into a probability that the input is real.

    Args:
        num_classes: number of distinct class labels (default 10).
        img_channels: number of channels of the input images (default 1).
    """

    # The convolution stack reduces exactly 28x28 inputs to a single 1x1 score,
    # so the spatial size is fixed rather than configurable.
    IMG_SIZE = 28

    def __init__(self, num_classes=10, img_channels=1):
        super().__init__()
        self.embed = nn.Embedding(num_classes, self.IMG_SIZE * self.IMG_SIZE)
        self.seq = nn.Sequential(
            # +1 input channel for the label plane appended in forward().
            nn.Conv2d(img_channels + 1, 16, kernel_size=3, stride=2, padding=1),  # 28x28 -> 14x14
            nn.LeakyReLU(0.2),
            nn.Conv2d(16, 32, kernel_size=4, stride=2, padding=1, bias=False),  # 14x14 -> 7x7
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1, bias=False),  # 7x7 -> 4x4
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 1, kernel_size=4, stride=2, padding=0),  # 4x4 -> 1x1
            nn.Sigmoid(),
        )

    def forward(self, x, labels):
        """Score images `x` given their class `labels`.

        Args:
            x: image tensor of shape (N, img_channels, 28, 28).
            labels: integer class indices of shape (N,).

        Returns:
            Tensor of shape (N, 1, 1, 1) with probabilities in [0, 1].
        """
        embedding = self.embed(labels).view(-1, 1, self.IMG_SIZE, self.IMG_SIZE)
        # Concatenate the image with the label plane.
        x = torch.cat([x, embedding], dim=1)  # (N, img_channels + 1, 28, 28)
        return self.seq(x)

In [62]:
# Number of latent noise channels fed to the generator.
z_dim = 100
# Flattened size of one image (height * width). `.data` replaces the deprecated
# `train_data` alias of the MNIST dataset.
mnist_dim = train_dataset.data.size(1) * train_dataset.data.size(2)
G = Generator().to(device)
D = Discriminator().to(device)
# Binary cross-entropy on probabilities: the discriminator already ends in a Sigmoid.
criterion = nn.BCELoss()
G_optimizer = optim.Adam(G.parameters(), lr=0.0004, betas=(0.5, 0.999))
D_optimizer = optim.Adam(D.parameters(), lr=0.0003, betas=(0.5, 0.999))

In [63]:
# Adversarial training of the conditional GAN. Each iteration takes one discriminator
# step (real batch + generated batch) and then one generator step. Every second epoch
# a 10x10 grid of samples is written to `sample_dir`.
n_epochs = 200
sample_dir = './conv_cegan_samples'
os.makedirs(sample_dir, exist_ok=True)  # save_image does not create missing directories

# Fixed labels (digits 0-9 repeated 10 times) and fixed noise, so that grids saved at
# different epochs are directly comparable and do not depend on the last training batch.
sample_labels = torch.tensor(list(range(10)) * 10, dtype=torch.int64, device=device)
sample_z = torch.randn(sample_labels.size(0), z_dim, 1, 1, device=device)

for epoch in range(0, n_epochs + 1):
    D_losses, G_losses = [], []
    for batch_idx, (x, label) in enumerate(train_loader):
        x_real, label = x.to(device), label.to(device)
        # Take the size from the batch itself so a smaller final batch also works.
        batch_size = x_real.size(0)
        y_real = torch.ones(batch_size, 1, device=device)
        y_fake = torch.zeros(batch_size, 1, device=device)

        # _______Discriminator____________
        D.zero_grad()

        D_real_loss = criterion(D(x_real, label).view(batch_size, 1), y_real)

        z = torch.randn(batch_size, z_dim, 1, 1, device=device)
        x_fake = G(z, label)
        # detach(): the discriminator step must not backpropagate into the generator.
        D_fake_loss = criterion(D(x_fake.detach(), label).view(batch_size, 1), y_fake)

        D_loss = D_real_loss + D_fake_loss
        D_loss.backward()
        D_optimizer.step()

        D_losses.append(D_loss.item())

        # _______Generator____________
        G.zero_grad()

        # Re-score the same fakes with the updated discriminator; the generator is
        # rewarded when they are classified as real.
        G_loss = criterion(D(x_fake, label).view(batch_size, 1), y_real)
        G_loss.backward()
        G_optimizer.step()
        G_losses.append(G_loss.item())

    print('[%d/%d]: loss_d: %.3f, loss_g: %.3f' % (
            epoch, n_epochs, torch.tensor(D_losses).mean().item(), torch.tensor(G_losses).mean().item()))

    if epoch % 2 == 0:
        # eval(): use BatchNorm running statistics and leave them untouched while sampling.
        G.eval()
        with torch.no_grad():
            generated = G(sample_z, sample_labels)
        G.train()
        # Tanh output lies in [-1, 1]; save_image clamps to [0, 1], so rescale first.
        # nrow=10 puts one digit class per column.
        save_image((generated.view(generated.size(0), 1, 28, 28) + 1) / 2,
                   os.path.join(sample_dir, '{}.png'.format(epoch)), nrow=10)

[0/200]: loss_d: 1.321, loss_g: 0.796
[1/200]: loss_d: 1.320, loss_g: 0.815
[2/200]: loss_d: 1.378, loss_g: 0.749
[3/200]: loss_d: 1.384, loss_g: 0.730
[4/200]: loss_d: 1.387, loss_g: 0.722
[5/200]: loss_d: 1.390, loss_g: 0.719
[6/200]: loss_d: 1.390, loss_g: 0.714
[7/200]: loss_d: 1.390, loss_g: 0.711
[8/200]: loss_d: 1.390, loss_g: 0.710
[9/200]: loss_d: 1.390, loss_g: 0.709
[10/200]: loss_d: 1.389, loss_g: 0.708
[11/200]: loss_d: 1.389, loss_g: 0.707


KeyboardInterrupt: 