In [1]:
import torch
import torch.nn as nn
import torchvision
import numpy as np
import pandas as pd
from PIL import Image, ImageFile
from torchvision import transforms
from matplotlib import pyplot as plt
from torchvision.datasets import ImageFolder
from torchsummary import  summary
from torch.utils.data import DataLoader
from tqdm.autonotebook import tqdm
from torch.cuda import amp
import torch.nn.functional as F
import torch
import torchvision.utils as vutils
import random
import torch
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import pickle

In [None]:
# Mount Google Drive at /content/gdrive (the google.colab package implies a Colab runtime).
from google.colab import drive
drive.mount('/content/gdrive')

Предотвращаем ошибки при загрузке датасета: снимаем ограничение PIL на размер изображения и разрешаем чтение усечённых файлов.

In [None]:
# Let cuDNN auto-tune its convolution algorithms for the input shapes it sees.
cudnn.benchmark = True
# Remove PIL's pixel-count limit so very large images can still be opened.
Image.MAX_IMAGE_PIXELS = None
# Let PIL load truncated image files instead of raising an error.
ImageFile.LOAD_TRUNCATED_IMAGES = True

Инициализирую основные переменные

In [None]:
# Use the GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs (slowly) on a machine without CUDA.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
EPOCHS = 350            # number of training epochs
BATCH_SIZE = 64         # samples per mini-batch
IMAGE_SIZE = 128        # images are resized to IMAGE_SIZE x IMAGE_SIZE
NUM_CLASSES = 27        # number of class labels the networks are conditioned on
FEATURE_MAP_GEN = 64    # base channel width of the generator
FEATURE_MAP_DISC = 32   # base channel width of the discriminator
NUM_CHANNELS = 3        # image channels
NOISE_SIZE = 150        # length of the generator's noise vector

In [None]:
# Run a garbage-collection pass and release PyTorch's cached CUDA memory.
import gc
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Download the WikiArt archive (wikiart.zip).
!wget http://web.fsktm.um.edu.my/~cschan/source/ICIP2017/wikiart.zip

In [None]:
%%time
# Unpack the downloaded archive into /content/train.
!unzip /content/wikiart.zip -d /content/train

In [None]:
# Root folder handed to ImageFolder, which expects one sub-directory per class.
train_directory = '/content/train/wikiart'

In [None]:
# Per-image preprocessing: square resize, random mirroring, tensor conversion,
# then per-channel rescaling from [0, 1] to [-1, 1].
transform = transforms.Compose(
    [
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

In [None]:
def denorm(img_tensors):
    """Undo a mean-0.5 / std-0.5 normalization, mapping [-1, 1] back to [0, 1].

    Args:
        img_tensors: normalized values (tensor, array or scalar).

    Returns:
        ``img_tensors * 0.5 + 0.5``.
    """
    mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
    # All channels share the same statistics, so the first entry is enough.
    scale = std[0]
    shift = mean[0]
    return img_tensors * scale + shift

In [None]:
%%time
dataset_fold  = ImageFolder(root = train_directory, transform = transform)
dataset_norm = [data for data in dataset_fold]

In [None]:
 # Number of cached (image, label) items.
 len(dataset_norm)

58986

In [None]:
# Shuffled mini-batches over the in-memory dataset; the incomplete final batch
# is dropped so every batch holds exactly BATCH_SIZE samples.
dataloader = DataLoader(
    dataset_norm,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
    drop_last=True,
)

Гауссовский шум добавляется перед каждым свёрточным слоем дискриминатора, что позволяет повысить качество обучения.

In [None]:
class GaussianNoise(nn.Module):
    """Adds zero-mean Gaussian noise to its input while in training mode.

    In evaluation mode the input is returned unchanged.

    Args:
        std: standard deviation of the added noise.
        decay_rate: amount subtracted from ``std`` on each ``decay_step`` call.
    """

    def __init__(self, std=0.1, decay_rate=0):
        super().__init__()
        self.decay_rate = decay_rate
        self.std = std

    def decay_step(self):
        """Reduce ``std`` by ``decay_rate``, clamping the result at zero."""
        reduced = self.std - self.decay_rate
        self.std = max(reduced, 0)

    def forward(self, x):
        """Return ``x`` plus fresh noise (training) or ``x`` itself (evaluation)."""
        if not self.training:
            return x
        noise = torch.empty_like(x).normal_(std=self.std)
        return x + noise

Веса свёрточных слоёв инициализируются из нормального распределения N(0, 0.02), веса BatchNorm — из N(1, 0.02), смещения BatchNorm обнуляются.

In [None]:
@torch.no_grad()
def weights_init(model):
    """Re-initialise one module in place; intended for ``nn.Module.apply``.

    Modules whose class name contains ``'Conv'`` get weights drawn from
    N(0, 0.02). Modules whose class name contains ``'BatchNorm'`` get weights
    drawn from N(1, 0.02) and zero biases. Any other module is left untouched.

    Args:
        model: the module to initialise.
    """
    layer_name = model.__class__.__name__
    if layer_name.find('Conv') != -1:
        nn.init.normal_(model.weight.data, mean=0.0, std=0.02)
        return
    if layer_name.find('BatchNorm') != -1:
        nn.init.normal_(model.weight.data, mean=1.0, std=0.02)
        nn.init.constant_(model.bias.data, val=0)

In [None]:
class Generator(nn.Module):
    """Conditional DCGAN generator: (noise, class label) -> image in [-1, 1].

    The label is embedded into a ``NUM_CLASSES``-dimensional vector and
    concatenated with the noise along the channel axis. A stack of transposed
    convolutions then upsamples from 1x1 to ``IMAGE_SIZE`` x ``IMAGE_SIZE``.

    The number of upsampling stages is derived from ``IMAGE_SIZE``, so the
    output always matches the resolution of the real images. For
    ``IMAGE_SIZE == 64`` the layer stack is identical to the classic
    four-stage layout (512 -> 256 -> 128 -> 64 -> image when
    ``FEATURE_MAP_GEN == 64``).

    Args:
        ngpu: stored as ``self.ngpu``; not used by this class itself.

    Raises:
        ValueError: if ``IMAGE_SIZE`` is not a power of two that is >= 8.
    """

    def __init__(self, ngpu=1):
        super(Generator, self).__init__()
        if IMAGE_SIZE < 8 or IMAGE_SIZE & (IMAGE_SIZE - 1):
            raise ValueError(
                "IMAGE_SIZE must be a power of two >= 8, got {}".format(IMAGE_SIZE))

        self.label_emb = nn.Embedding(NUM_CLASSES, NUM_CLASSES)
        self.ngpu = ngpu

        # The first layer maps 1x1 -> 4x4; each stride-2 layer doubles the side.
        # bit_length() - 1 is log2 for a power of two, so this is log2(IMAGE_SIZE) - 2.
        num_doublings = IMAGE_SIZE.bit_length() - 3
        # Channel width at 4x4; it is halved by every stage until it reaches
        # FEATURE_MAP_GEN just before the output layer.
        width = FEATURE_MAP_GEN * 2 ** (num_doublings - 1)

        layers = [
            # (NOISE_SIZE + NUM_CLASSES) x 1 x 1 -> width x 4 x 4
            nn.ConvTranspose2d(NOISE_SIZE + NUM_CLASSES, width, 4, 1, 0, bias=False),
            nn.BatchNorm2d(width),
            nn.ReLU(True),
        ]
        for _ in range(num_doublings - 1):
            # Double the spatial size, halve the channel count.
            layers += [
                nn.ConvTranspose2d(width, width // 2, 4, 2, 1, bias=False),
                nn.BatchNorm2d(width // 2),
                nn.ReLU(True),
            ]
            width //= 2
        # Final doubling to NUM_CHANNELS x IMAGE_SIZE x IMAGE_SIZE, squashed to [-1, 1].
        layers += [
            nn.ConvTranspose2d(width, NUM_CHANNELS, 4, 2, 1, bias=False),
            nn.Tanh(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, noise_input, labels):
        """Generate one image per (noise vector, label) pair.

        Args:
            noise_input: tensor of shape (batch, NOISE_SIZE, 1, 1).
            labels: integer tensor of shape (batch,) with class indices.

        Returns:
            Tensor of shape (batch, NUM_CHANNELS, IMAGE_SIZE, IMAGE_SIZE).
        """
        # Concatenate the label embedding with the input noise along channels.
        label_maps = self.label_emb(labels).unsqueeze(2).unsqueeze(3)
        gen_input = torch.cat((label_maps, noise_input), 1)
        return self.main(gen_input)

In [None]:

class Discriminator(nn.Module):
    """Conditional DCGAN discriminator: (image, class label) -> P(real).

    Five stride-2 convolutions, each preceded by a ``GaussianNoise`` layer,
    halve the spatial size five times. The flattened feature map is
    concatenated with a learned label embedding and classified by a small MLP
    ending in a sigmoid.

    The MLP input width is derived from ``IMAGE_SIZE``. For
    ``IMAGE_SIZE == 64`` it equals ``FEATURE_MAP_DISC * 128``, matching the
    classic 64x64 layout.

    Args:
        ngpu: stored as ``self.ngpu``; not used by this class itself.

    Raises:
        ValueError: if ``IMAGE_SIZE`` is smaller than 32, because five
            halvings would leave no spatial extent.
    """

    def __init__(self, ngpu=1):
        super(Discriminator, self).__init__()
        if IMAGE_SIZE < 32:
            raise ValueError(
                "IMAGE_SIZE must be at least 32, got {}".format(IMAGE_SIZE))

        label_dim = FEATURE_MAP_DISC * 64
        self.label_emb = nn.Embedding(NUM_CLASSES, label_dim)
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # NUM_CHANNELS x S x S, with S = IMAGE_SIZE
            GaussianNoise(),
            nn.Conv2d(NUM_CHANNELS, FEATURE_MAP_DISC, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # FEATURE_MAP_DISC x S/2 x S/2
            GaussianNoise(),
            nn.Conv2d(FEATURE_MAP_DISC, FEATURE_MAP_DISC * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(FEATURE_MAP_DISC * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # (FEATURE_MAP_DISC * 2) x S/4 x S/4
            GaussianNoise(),
            nn.Conv2d(FEATURE_MAP_DISC * 2, FEATURE_MAP_DISC * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(FEATURE_MAP_DISC * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # (FEATURE_MAP_DISC * 4) x S/8 x S/8
            GaussianNoise(),
            nn.Conv2d(FEATURE_MAP_DISC * 4, FEATURE_MAP_DISC * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(FEATURE_MAP_DISC * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # (FEATURE_MAP_DISC * 8) x S/16 x S/16
            GaussianNoise(),
            nn.Conv2d(FEATURE_MAP_DISC * 8, FEATURE_MAP_DISC * 16, 4, 2, 1, bias=False),
            nn.BatchNorm2d(FEATURE_MAP_DISC * 16),
            nn.LeakyReLU(0.2, inplace=True),
            # (FEATURE_MAP_DISC * 16) x S/32 x S/32
            nn.Flatten()
        )
        # Each 4/2/1 convolution halves the side (rounding down), so five of
        # them leave IMAGE_SIZE // 32 pixels per side.
        side = IMAGE_SIZE // 32
        conv_features = FEATURE_MAP_DISC * 16 * side * side
        self.linear = nn.Sequential(
            nn.Linear(conv_features + label_dim, FEATURE_MAP_DISC * 16),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(FEATURE_MAP_DISC * 16, 1),
            nn.Sigmoid()
        )

    def forward(self, input, labels):
        """Score a batch of images given their class labels.

        Args:
            input: tensor of shape (batch, NUM_CHANNELS, IMAGE_SIZE, IMAGE_SIZE).
            labels: integer tensor of shape (batch,) with class indices.

        Returns:
            Tensor of shape (batch, 1, 1, 1) with values in [0, 1].
        """
        disc_out = self.main(input)
        # Append the label embedding to the flattened convolutional features.
        linear_input = torch.cat((self.label_emb(labels), disc_out), 1)
        # No squeeze here: it would drop the batch axis for a batch of one and
        # break the reshape below.
        linear_output = self.linear(linear_input)

        return linear_output.unsqueeze(2).unsqueeze(3)

In [None]:
# Build both networks and move them to the training device.
netG = Generator().to(DEVICE)
netD = Discriminator().to(DEVICE)
# Initialise the weights.
netG.apply(weights_init)
netD.apply(weights_init)

# Binary cross-entropy, applied to the discriminator's sigmoid output.
criterion = nn.BCELoss()

Наилучшее качество достигнуто при выбранных параметрах оптимизатора; добавление scheduler не дало явных улучшений.

In [None]:
# One Adam optimizer per network with identical settings (lr 1e-4, betas 0.5 / 0.999).
optG = torch.optim.Adam(netG.parameters(), lr= 0.0001, betas= (0.5, 0.999))
optD = torch.optim.Adam(netD.parameters(), lr= 0.0001, betas= (0.5, 0.999))

Фиксированные значения для шума и меток (используются для контрольной визуализации в процессе обучения).

In [None]:
# Sixteen fixed noise vectors paired with class labels 0..15, reused for every
# preview grid so successive previews are directly comparable.
fixed_noise = torch.randn(16, NOISE_SIZE, 1, 1, device=DEVICE)
fixed_labels = torch.arange(16, device=DEVICE)

In [None]:
# Module-level lists; train_epoch does not write to them.
loss_gen = []
loss_disc = []

def train_epoch(train_loader, netG, netD, optG, optD, noise_dim, epochs, batch_size, device=DEVICE):
    """Train the conditional GAN for ``epochs`` epochs.

    Each batch performs one discriminator step and one generator step. The
    discriminator sees real images with targets drawn from U(0.9, 1.0) and
    generated images with targets drawn from U(0.0, 0.1). The generator is
    trained against targets of 1.

    Every 15th epoch, and after the final one, a grid generated from the
    module-level ``fixed_noise`` / ``fixed_labels`` is displayed and a
    checkpoint is written to ``./checkpoint_norm.tar``.

    Args:
        train_loader: iterable yielding batches where ``batch[0]`` is the
            image tensor and ``batch[1]`` the label tensor.
        netG: generator, called as ``netG(noise, labels)``.
        netD: discriminator, called as ``netD(images, labels)``.
        optG: optimizer for ``netG``.
        optD: optimizer for ``netD``.
        noise_dim: number of channels of the noise tensor fed to ``netG``.
        epochs: number of passes over ``train_loader``.
        batch_size: kept for backward compatibility; noise and targets are
            sized from each actual batch, so a short final batch also works.
        device: device on which batches, noise and targets are placed.

    Returns:
        Tuple ``(losses_g, losses_d, real_scores, fake_scores)`` of lists
        holding one per-epoch mean each.
    """
    netD.train()
    losses_g = []
    losses_d = []
    real_scores = []
    fake_scores = []

    for epoch in tqdm(range(epochs)):
        # The preview below switches the generator to eval mode; switch back.
        netG.train()
        torch.cuda.empty_cache()
        loss_d_per_epoch = []
        loss_g_per_epoch = []
        real_score_per_epoch = []
        fake_score_per_epoch = []

        for batch in train_loader:
            real_images = batch[0].to(device)
            labels = batch[1].to(device)
            n_samples = real_images.size(0)

            # ---- discriminator step ----
            optD.zero_grad()

            # Real images, with smoothed "real" targets.
            real_preds = netD(real_images, labels)
            real_targets = torch.empty_like(real_preds).uniform_(0.9, 1.0)
            real_loss = criterion(real_preds, real_targets)
            cur_real_score = torch.mean(real_preds).item()

            # Generated images, with smoothed "fake" targets. The generator is
            # not updated in this step, so no graph is built through it.
            noise = torch.randn(n_samples, noise_dim, 1, 1, device=device)
            with torch.no_grad():
                gen_fake = netG(noise, labels)
            fake_out = netD(gen_fake, labels)
            fake_targets = torch.empty_like(fake_out).uniform_(0.0, 0.1)
            fake_loss = criterion(fake_out, fake_targets)
            cur_fake_score = torch.mean(fake_out).item()

            real_score_per_epoch.append(cur_real_score)
            fake_score_per_epoch.append(cur_fake_score)

            loss_d = real_loss + fake_loss
            loss_d.backward()
            optD.step()
            loss_d_per_epoch.append(loss_d.item())

            # ---- generator step ----
            optG.zero_grad()

            noise = torch.randn(n_samples, noise_dim, 1, 1, device=device)
            gen_fake = netG(noise, labels)

            # Try to fool the discriminator: generated images should score 1.
            preds = netD(gen_fake, labels)
            targets = torch.ones_like(preds)
            loss_g = criterion(preds, targets)

            loss_g.backward()
            optG.step()

            loss_g_per_epoch.append(loss_g.item())

        losses_g.append(np.mean(loss_g_per_epoch))
        losses_d.append(np.mean(loss_d_per_epoch))
        real_scores.append(np.mean(real_score_per_epoch))
        fake_scores.append(np.mean(fake_score_per_epoch))

        print("Epoch [{}/{}], loss_g: {:.4f}, loss_d: {:.4f}, real_score: {:.4f}, fake_score: {:.4f}".format(
            epoch+1, epochs,
            losses_g[-1], losses_d[-1], real_scores[-1], fake_scores[-1]))

        # Preview and checkpoint every 15 epochs, and always after the last
        # epoch so the final weights are never lost.
        if epoch % 15 == 0 or epoch == epochs - 1:
            netG.eval()
            with torch.no_grad():
                fake_images_pictured = netG(fixed_noise, fixed_labels)
            grid = vutils.make_grid(fake_images_pictured, padding=2, normalize=True).cpu()
            plt.figure(figsize=(16,16))
            # make_grid returns channels-first; imshow expects channels-last.
            plt.imshow(grid.permute(1, 2, 0).numpy())
            plt.show()
            torch.save({
                'model_netG': netG.state_dict(),
                'model_netD': netD.state_dict(),
                'optimizer_Gen': optG.state_dict(),
                'optimizer_Disc': optD.state_dict()}, './checkpoint_norm.tar')
    return losses_g, losses_d, real_scores, fake_scores

In [None]:
# Run the full training and keep the per-epoch statistics for later inspection.
losses_g2, losses_d2, real_scores2, fake_scores2 = train_epoch(
    dataloader,
    netG,
    netD,
    optG,
    optD,
    NOISE_SIZE,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
)