In [None]:
import torch
import torch.nn as nn

import os

import numpy as np

# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_device_name)

In [None]:
def weights_init(m):
    """Re-initialise module ``m`` in place, dispatching on its class name.

    * Class name contains ``'Conv'``: weights are drawn from a normal
      distribution with mean 0.0 and std 0.02.
    * Otherwise, class name contains ``'BatchNorm'``: weights are drawn from a
      normal distribution with mean 1.0 and std 0.02, and the bias is set to 0.
    * Any other module is left untouched.

    Meant to be handed to ``nn.Module.apply`` so it visits every sub-module.
    """
    layer_kind = type(m).__name__
    if 'Conv' in layer_kind:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        return
    if 'BatchNorm' in layer_kind:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

class BasicBlock(nn.Module):
    """Conv2d -> BatchNorm2d -> activation building block.

    Args:
        inp_ch: number of input channels.
        out_ch: number of output channels.
        kernel_size: convolution kernel size.
        padding: zero-padding forwarded to ``nn.Conv2d``.
        stride: stride forwarded to ``nn.Conv2d``.
        activation: ``'LeakyReLu'`` (default) or ``'Tanh'``.

    Raises:
        ValueError: if ``activation`` is not one of the supported names.
    """

    def __init__(self, inp_ch, out_ch, kernel_size, padding, stride, activation = 'LeakyReLu') -> None:
        super().__init__()
        # Forward the caller's padding/stride. They used to be hard-coded to 2,
        # which silently turned every "stride=1" block into a stride-2 block.
        self.conv = nn.Conv2d(inp_ch, out_ch, kernel_size, padding=padding, stride=stride)
        # NOTE(review): in PyTorch `momentum` is the weight given to the *new*
        # batch statistic, so 0.9 makes the running stats track the latest batch
        # almost entirely. If this value was ported from a framework with the
        # opposite convention, 0.1 is probably what was meant - confirm.
        self.bn = nn.BatchNorm2d(out_ch, momentum=0.9)
        if activation == 'LeakyReLu':
            self.activation = nn.LeakyReLU()
        elif activation == 'Tanh':
            self.activation = nn.Tanh()
        else:
            # `raise(ValueError, msg)` would raise a tuple, i.e. a TypeError.
            raise ValueError('Unknown activation func: %r' % (activation,))

        # The sub-modules stay registered both individually and inside `layer`
        # so existing state_dict keys (conv.*, bn.*, layer.*) are unchanged.
        self.layer = nn.Sequential(
            self.conv,
            self.bn,
            self.activation,
        )

    def forward(self, x):
        """Apply convolution, batch norm and the activation to ``x``."""
        return self.layer(x)

class BasicDeconv(nn.Module):
    """ConvTranspose2d -> BatchNorm2d -> activation building block.

    Args:
        inp_ch: number of input channels.
        out_ch: number of output channels.
        kernel_size: transposed-convolution kernel size.
        padding: padding forwarded to ``nn.ConvTranspose2d``.
        stride: stride forwarded to ``nn.ConvTranspose2d``.
        activation: ``'LeakyReLu'`` (default), ``'Tanh'`` or ``'Sigmoid'``.

    Raises:
        ValueError: if ``activation`` is not one of the supported names.
    """

    def __init__(self, inp_ch, out_ch, kernel_size, padding, stride, activation = 'LeakyReLu') -> None:
        super().__init__()
        # Forward the caller's padding/stride. They used to be hard-coded to 2,
        # so a "stride=1" output layer still doubled the spatial size.
        self.deconv = nn.ConvTranspose2d(inp_ch, out_ch, kernel_size, padding=padding, stride=stride)
        # NOTE(review): in PyTorch `momentum` is the weight given to the *new*
        # batch statistic, so 0.9 makes the running stats track the latest batch
        # almost entirely. If this value was ported from a framework with the
        # opposite convention, 0.1 is probably what was meant - confirm.
        self.bn = nn.BatchNorm2d(out_ch, momentum=0.9)
        if activation == 'LeakyReLu':
            self.activation = nn.LeakyReLU()
        elif activation == 'Tanh':
            self.activation = nn.Tanh()
        elif activation == 'Sigmoid':
            self.activation = nn.Sigmoid()
        else:
            # `raise(ValueError, msg)` would raise a tuple, i.e. a TypeError.
            raise ValueError('Unknown activation func: %r' % (activation,))

        # The sub-modules stay registered both individually and inside `layer`
        # so existing state_dict keys (deconv.*, bn.*, layer.*) are unchanged.
        self.layer = nn.Sequential(
            self.deconv,
            self.bn,
            self.activation,
        )

    def forward(self, x):
        """Apply transposed convolution, batch norm and the activation to ``x``."""
        return self.layer(x)

In [None]:
class Encoder(nn.Module):
    """Convolutional encoder that outputs a latent mean and log-variance.

    Three ``BasicBlock`` stages (1 -> 64 -> 128 -> 256 channels, each requested
    with kernel 5, padding 2, stride 2) feed a 2048-unit linear layer with batch
    norm and ReLU, followed by two independent 128-dimensional linear heads.
    The first linear layer expects ``256*8*8`` flattened features per sample.
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor.
        self.layer1 = BasicBlock(1, 64, 5, padding=2, stride=2)
        self.layer2 = BasicBlock(64, 128, 5, padding=2, stride=2)
        self.layer3 = BasicBlock(128, 256, 5, padding=2, stride=2)

        # Dense trunk shared by both heads.
        self.fc1 = nn.Linear(256 * 8 * 8, 2048)
        self.bn4 = nn.BatchNorm1d(2048, momentum=0.9)

        # Heads producing the parameters of the approximate posterior.
        self.fc_mean = nn.Linear(2048, 128)
        self.fc_logvar = nn.Linear(2048, 128)

    def forward(self, x):
        """Return ``(mean, logvar)`` for the batch ``x``."""
        n_samples = x.shape[0]
        features = self.layer3(self.layer2(self.layer1(x)))
        flat = features.view(n_samples, -1)
        hidden = torch.relu(self.bn4(self.fc1(flat)))
        return self.fc_mean(hidden), self.fc_logvar(hidden)

class Decoder(nn.Module):
    """Maps a 128-dimensional latent vector to a single-channel image.

    A linear layer expands the latent code to ``8*8*256`` features, which are
    reshaped to ``(N, 256, 8, 8)`` and passed through four ``BasicDeconv``
    stages; the last one uses a Tanh activation.
    With the requested strides (2, 2, 2, 1) the spatial size would go
    8 -> 16 -> 32 -> 64 -> 64, assuming ``BasicDeconv`` applies the stride and
    padding it is given - TODO confirm against ``BasicDeconv``.
    """

    def __init__(self):
        super(Decoder, self).__init__()
        self.fc1 = nn.Linear(128, 8 * 8 * 256)
        self.bn1 = nn.BatchNorm1d(8 * 8 * 256, momentum=0.9)
        # Named `relu` for historical reasons; it is a LeakyReLU(0.2).
        self.relu = nn.LeakyReLU(0.2)

        self.deconv1 = BasicDeconv(256, 256, 6, stride=2, padding=2)
        self.deconv2 = BasicDeconv(256, 128, 6, stride=2, padding=2)
        self.deconv3 = BasicDeconv(128, 32, 6, stride=2, padding=2)
        self.deconv4 = BasicDeconv(32, 1, 5, stride=1, padding=2, activation='Tanh')

    def forward(self, x):
        """Decode latent batch ``x`` of shape ``(N, 128)`` into images.

        Returns the output of the last deconvolution stage. Previously the
        flattened ``(N, 8*8*256)`` activations of the first linear layer were
        returned by mistake, so the deconvolution stack had no effect.
        """
        hidden = self.relu(self.bn1(self.fc1(x)))
        out = hidden.view(-1, 256, 8, 8)
        out = self.deconv1(out)
        out = self.deconv2(out)
        out = self.deconv3(out)
        out = self.deconv4(out)
        return out

class Discriminator(nn.Module):
    """Convolutional real/fake classifier that also exposes its conv features.

    Four ``BasicBlock`` stages (requested strides 1, 2, 2, 2) are followed by a
    512-unit linear layer with batch norm and LeakyReLU and a single sigmoid
    output. ``fc1`` expects ``8*8*256`` features per sample, i.e. an 8x8 map
    after the last block.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        self.layer1 = BasicBlock(1, 32, 5, padding=2, stride=1)
        self.layer2 = BasicBlock(32, 128, 5, padding=2, stride=2)
        self.layer3 = BasicBlock(128, 256, 5, padding=2, stride=2)
        self.layer4 = BasicBlock(256, 256, 5, padding=2, stride=2)
        self.fc1 = nn.Linear(8 * 8 * 256, 512)
        self.bn = nn.BatchNorm1d(512, momentum=0.9)
        self.fc2 = nn.Linear(512, 1)
        self.sigmoid = nn.Sigmoid()
        # Named `relu` for historical reasons; it is a LeakyReLU.
        self.relu = nn.LeakyReLU()

    def forward(self, x):
        """Classify the image batch ``x``.

        Returns:
            A pair ``(prob, features)``: the sigmoid output of shape ``(N, 1)``
            and the flattened output of the last conv block, shape ``(N, F)``.
        """
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        # Flatten per sample. `view(-1, 256*8*8)` would silently merge several
        # samples into one row whenever the feature map is not 8x8; this way a
        # size mismatch fails loudly at `fc1` instead.
        x = x.view(x.size(0), -1)
        x1 = x
        x = self.relu(self.bn(self.fc1(x)))
        x = self.sigmoid(self.fc2(x))
        return x, x1

class VAE_GAN(nn.Module):
    """Encoder/decoder pair forming the generator side of the VAE-GAN.

    Both sub-networks are re-initialised with ``weights_init`` on construction.
    """

    def __init__(self):
        super(VAE_GAN, self).__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()
        self.encoder.apply(weights_init)
        self.decoder.apply(weights_init)

    def forward(self, x):
        """Encode ``x``, sample a latent code and decode it.

        Returns:
            ``(z_mean, z_logvar, x_tilda)`` where ``x_tilda`` is the decoder
            output for ``z = z_mean + exp(0.5 * z_logvar) * epsilon``.
        """
        z_mean, z_logvar = self.encoder(x)
        std = torch.exp(0.5 * z_logvar)
        # Reparameterisation trick. `randn_like` draws the noise with the same
        # shape, dtype and device as the encoder output, so the module no longer
        # depends on a global `device` or on a hard-coded latent size.
        epsilon = torch.randn_like(std)
        z = z_mean + std * epsilon
        x_tilda = self.decoder(z)
        return z_mean, z_logvar, x_tilda


## Dataloader

In [None]:
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
class MNIST_dataset(Dataset):
    """Dataset over an indexable collection of single-channel images.

    Each item is returned as a float tensor with a leading channel axis,
    divided by 255 and resized to ``image_size``. No label is returned.

    Args:
        x: indexable collection of images; each entry is assumed to be a 2-D
            ``(H, W)`` array with values in 0..255 - TODO confirm against the
            data source.
        image_size: ``(height, width)`` passed to ``transforms.Resize``.
            Defaults to ``(64, 64)``.
    """

    def __init__(self, x, image_size=(64, 64)):

        super().__init__()
        self.data = x
        # Build the resize transform once instead of on every item access.
        self._resize = transforms.Resize(image_size)

    def __getitem__(self, index):
        """Return the image at ``index`` as a resized float tensor (image only, no label)."""
        img = self.data[index]
        # Add the channel axis and scale by 1/255.
        img = torch.Tensor(np.expand_dims(img, axis=0)) / 255.
        return self._resize(img)

    def __len__(self):
        """Number of images in the dataset."""
        return len(self.data)

In [None]:
import numpy as np
import matplotlib.pyplot as plt

def save_example(file_name, img):
    """Save an image tensor as a PNG in the working directory and display it.

    Args:
        file_name: output name. ``.png`` is appended unless the name already
            ends with it (this avoids ``name.png.png``). Also used as the
            figure title.
        img: 3-D channel-first tensor, e.g. the output of ``make_grid``; it is
            transposed to channel-last before saving.
    """
    # detach/cpu make this safe for tensors that live on the GPU or carry grad.
    npimg = np.transpose(img.detach().cpu().numpy(), (1, 2, 0))
    if np.issubdtype(npimg.dtype, np.floating):
        # imsave rejects floating-point RGB values outside [0, 1].
        npimg = np.clip(npimg, 0.0, 1.0)

    stem = file_name[:-4] if file_name.endswith('.png') else file_name
    f = "./%s.png" % stem
    plt.imsave(f, npimg)

    # Draw the image on the titled figure (it used to be left blank), then
    # close it so repeated calls do not accumulate open figures.
    fig = plt.figure(dpi=200)
    fig.suptitle(file_name, fontsize=14, fontweight='bold')
    ax = fig.add_subplot(1, 1, 1)
    ax.imshow(npimg)
    ax.axis('off')
    plt.show()
    plt.close(fig)

## Training

In [None]:
import numpy as np
def load_data(path):
    """Read the train and test splits from the ``.npz`` archive at ``path``.

    The archive must contain the keys ``x_train``, ``y_train``, ``x_test`` and
    ``y_test``.

    Returns:
        ``((x_train, y_train), (x_test, y_test))``.
    """
    with np.load(path) as archive:
        train_split = (archive['x_train'], archive['y_train'])
        test_split = (archive['x_test'], archive['y_test'])
    return train_split, test_split

# Read both splits from the local archive; only the training images feed the dataset.
train_split, test_split = load_data('3. mnist.npz')
x_train, y_train = train_split
x_test, y_test = test_split
train_dataset = MNIST_dataset(x=x_train)

In [None]:
# Batches of 64 images. Shuffle each epoch so the training batches are not
# served in the fixed order in which the images are stored.
data_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

In [None]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision.utils import make_grid
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from torch.autograd import Variable
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Autograd anomaly detection is a debugging aid that slows every
# forward/backward pass; keep it off unless a backward error is being traced.
DETECT_ANOMALY = False
torch.autograd.set_detect_anomaly(DETECT_ANOMALY)

# Directory the per-epoch checkpoints are written to.
save_checkpoint = 'checkpoints'

# Models
generator = VAE_GAN().to(device)
discrim = Discriminator().to(device)

# Training hyper-parameters
epochs = 5
lr = 3e-4
alpha = 0.1   # multiplier applied to `lr` for the discriminator optimizer
gamma = 15    # weight of the feature reconstruction term in the decoder loss

# Loss function and one optimizer per sub-network
criterion = nn.BCELoss().to(device)
optim_E = torch.optim.Adam(generator.encoder.parameters(), lr=lr)
optim_D = torch.optim.Adam(generator.decoder.parameters(), lr=lr)
optim_Dis = torch.optim.Adam(discrim.parameters(), lr=lr * alpha)

# torch.save does not create missing directories; create it up front so the
# first checkpoint does not fail after a full epoch of training.
os.makedirs(save_checkpoint, exist_ok=True)

# Latent batch sampled once so the generated samples are comparable across epochs.
z_fixed = torch.randn(64, 128, device=device)

for epoch in range(epochs):
    encoder_loss_list, gan_loss_list, decoder_loss_list = [], [], []
    dis_real_list, dis_fake_list, dis_noise_list = [], [], []
    for data in tqdm(data_loader):

        bs = data.size()[0]

        ones_label = torch.ones(bs, 1, device=device)
        zeros_label = torch.zeros(bs, 1, device=device)
        # The prior-noise batch always has 64 samples, independent of `bs`.
        zeros_label1 = torch.zeros(64, 1, device=device)

        x_r = data.to(device)

        # Reconstruct the real batch.
        mean, logvar, x_f = generator(x_r)

        # Decode images from prior noise.
        z_p = torch.randn(64, 128, device=device)
        x_p = generator.decoder(z_p)

        # ---- Discriminator update -------------------------------------
        # The generated images are detached: this step only needs gradients
        # for the discriminator, and it leaves the generator graph intact
        # for the decoder update below (no retain_graph needed).
        output, x_l = discrim(x_r)
        errD_real = criterion(output, ones_label)
        dis_real_list.append(errD_real.item())

        output, x_l_tilda = discrim(x_f.detach())
        errD_fake = criterion(output, zeros_label)
        dis_fake_list.append(errD_fake.item())

        output = discrim(x_p.detach())[0]
        errD_noise = criterion(output, zeros_label1)
        dis_noise_list.append(errD_noise.item())

        gan_loss = errD_real + errD_fake + errD_noise
        gan_loss_list.append(gan_loss.item())
        optim_Dis.zero_grad()
        gan_loss.backward()
        optim_Dis.step()

        # ---- Decoder update -------------------------------------------
        # Re-evaluate with the freshly updated discriminator.
        output, x_l = discrim(x_r)
        errD_real = criterion(output, ones_label)

        output, x_l_tilda = discrim(x_f)
        errD_rec_enc = criterion(output, zeros_label)
        output = discrim(x_p)[0]
        errD_rec_noise = criterion(output, zeros_label1)
        gan_loss = errD_real + errD_rec_enc + errD_rec_noise

        # Feature-space reconstruction error minus the discriminator loss:
        # the decoder tries to fool the discriminator.
        rec_loss = ((x_l_tilda - x_l) ** 2).mean()
        err_dec = gamma * rec_loss - gan_loss
        # Log the decoder objective in its own list; it used to be appended
        # to `encoder_loss_list`, leaving `decoder_loss_list` empty.
        decoder_loss_list.append(err_dec.item())
        optim_D.zero_grad()
        err_dec.backward()
        optim_D.step()

        # ---- Encoder update -------------------------------------------
        mean, logvar, x_f = generator(x_r)
        x_l_tilda = discrim(x_f)[1]
        x_l = discrim(x_r)[1]
        rec_loss = ((x_l_tilda - x_l) ** 2).mean()
        # KL divergence to the standard normal prior, averaged over all latent elements.
        prior_loss = 1 + logvar - mean.pow(2) - logvar.exp()
        prior_loss = (-0.5 * torch.sum(prior_loss)) / torch.numel(mean.data)
        err_enc = prior_loss + 5 * rec_loss
        encoder_loss_list.append(err_enc.item())

        optim_E.zero_grad()
        err_enc.backward()
        optim_E.step()

    print('Loss_gan: %.4f\tEncoder_loss: %.4f\tDecoder_loss: %.4f\tdis_real_loss: %0.4f\tdis_fake_loss: %.4f\tdis_prior_loss: %.4f'
          % (np.mean(gan_loss_list), np.mean(encoder_loss_list), np.mean(decoder_loss_list),
             np.mean(dis_real_list), np.mean(dis_fake_list), np.mean(dis_noise_list)))

    # ---- Example generation -------------------------------------------
    # The loader yields image tensors (no labels), so the whole batch is
    # used; `real_batch[0]` would pick a single image.
    real_batch = next(iter(data_loader))
    x_fixed = real_batch.to(device)
    with torch.no_grad():
        b = generator(x_fixed)[2]
        c = generator.decoder(z_fixed)
    # save_example adds the ".png" extension itself.
    save_example('MNISTrec_noise_epoch_%d' % epoch, make_grid((c * 0.5 + 0.5).cpu(), 8))
    save_example('MNISTrec_epoch_%d' % epoch, make_grid((b * 0.5 + 0.5).cpu(), 8))
    torch.save(generator.state_dict(), os.path.join(save_checkpoint, 'gen_' + str(epoch) + '.pt'))
    torch.save(discrim.state_dict(), os.path.join(save_checkpoint, 'disc_' + str(epoch) + '.pt'))