In [None]:

# Mount Google Drive; the dataset, checkpoints and generated images all live under /content/drive/MyDrive.
from google.colab import drive

drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [None]:
"""
Import the libraries needed to build the generative adversarial network (GAN).
The code is written mainly with the PyTorch library.
"""
import time
import torch
import torch.nn as nn
from torch.nn import init
from torch.nn import functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.io import read_image, write_png
from torchvision.transforms.functional import pil_to_tensor
from torchvision.utils import save_image
from torch.nn.utils import spectral_norm
from torch.utils.data import DataLoader
from torch.autograd import Variable
import torchvision.transforms as T
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from random import triangular
from tqdm import tqdm

In [None]:
# `Tensor` is the plain torch.Tensor type (constructing it yields CPU tensors).
# Import it from its public location rather than the CUDA RNG helper module,
# which only re-exports it.
from torch import Tensor

batch_size = 24  # images per training batch

def rgba_loader(path) -> Tensor:
    """Open an image file and return it as an RGBA float tensor scaled to [-1, 1].

    Used as the ``loader`` for ImageFolder, so every sample has 4 channels
    regardless of the image's mode on disk.
    """
    with open(path, 'rb') as handle:
        # convert() loads the pixel data, so it has to run while the file is open.
        pixels = pil_to_tensor(Image.open(handle).convert('RGBA'))
    scaled = pixels.float() / 255 * 2
    return scaled - 1

# ImageFolder treats each sub-directory of the root as a class; labels are ignored later.
dataset = ImageFolder("/content/drive/MyDrive/SkinsTrimmed", loader=rgba_loader)
dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

In [None]:
"""
Select the compute device: use a GPU (CUDA) if one is available, otherwise the CPU.
"""
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')  # GPU when present

In [None]:
"""
Network architectures
The discriminator and generator architectures are defined below.
"""
class DisConvBlock(nn.Module):
    """Discriminator building block: spectral-norm Conv2d -> LeakyReLU(0.1) -> Dropout(0.5).

    ``self.norm`` (BatchNorm2d) is constructed but never applied in ``forward``;
    it still contributes parameters/buffers to ``state_dict``, so removing it
    would change checkpoint keys.
    """

    def __init__(self, in_channels, out_channels, kernel_size, padding):
        super().__init__()
        plain_conv = torch.nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            padding=padding,
            bias=False,
        )
        self.conv = spectral_norm(plain_conv)
        self.norm = nn.BatchNorm2d(out_channels)
        self.act = nn.LeakyReLU(0.1)
        self.drop = nn.Dropout(0.5)

    def forward(self, x):
        return self.drop(self.act(self.conv(x)))

class GenConvBlock(nn.Module):
    """Generator building block: ConvTranspose2d -> BatchNorm2d -> ReLU -> Dropout(0.5).

    The transposed conv uses stride 1, so for an odd ``kernel_size`` k and
    ``padding`` (k - 1) // 2 the spatial size is unchanged (e.g. k=5, padding=2).
    """

    def __init__(self, in_channels, out_channels, kernel_size, padding):
        super(GenConvBlock, self).__init__()
        # The bias is redundant in front of BatchNorm, but it is kept so the
        # state_dict keys (conv.bias) stay the same and saved checkpoints load.
        self.conv = torch.nn.ConvTranspose2d(in_channels=in_channels,
                                             out_channels=out_channels,
                                             kernel_size=kernel_size,
                                             padding=padding,
                                             bias=True)
        self.norm = nn.BatchNorm2d(out_channels)
        # Previously nn.ReLU(0.1): ReLU's only argument is `inplace`, so 0.1 just
        # turned on in-place mode. Plain ReLU computes the same function without
        # mutating its input. Use nn.LeakyReLU(0.1) if a 0.1 slope was intended.
        self.act = nn.ReLU()
        self.drop = nn.Dropout(0.5)

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        x = self.act(x)
        x = self.drop(x)
        return x

class SelfAttention(nn.Module):
    """SAGAN-style self-attention over all spatial positions of a feature map.

    Input and output have shape (batch, C, H, W). Query/key projections reduce
    the channels to C // 8; the value projection keeps C channels. The attended
    features are scaled by the learnable ``gamma`` (initialised to 0, so the
    block starts out as the identity) and added back to the input.
    """

    def __init__(self, in_channels):
        super(SelfAttention, self).__init__()
        self.f = torch.nn.Conv2d(in_channels, in_channels // 8, 1)  # query projection
        self.g = torch.nn.Conv2d(in_channels, in_channels // 8, 1)  # key projection
        self.h = torch.nn.Conv2d(in_channels, in_channels, 1)       # value projection
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        batch_size, channels, height, width = x.size()
        n = height * width
        query = self.f(x).view(batch_size, -1, n).permute(0, 2, 1)  # (B, N, C/8)
        key = self.g(x).view(batch_size, -1, n)                      # (B, C/8, N)
        # attention[b, i, j] = weight of position j for output position i;
        # the softmax over the last axis makes each row sum to 1.
        attention = self.softmax(torch.bmm(query, key))              # (B, N, N)
        value = self.h(x).view(batch_size, -1, n)                    # (B, C, N)
        # Contract over j (the normalised axis): out[:, :, i] = sum_j value[:, :, j] * attention[:, i, j].
        # Without the transpose, outputs were summed over i with weights that do not sum to 1.
        out = torch.bmm(value, attention.permute(0, 2, 1))           # (B, C, N)
        out = out.view(batch_size, channels, height, width)
        return self.gamma * out + x


class discriminator(nn.Module):
    """Discriminator for 4-channel 64x64 images; returns a (batch, 1) tensor of probabilities.

    The three unpadded 5x5 conv blocks shrink 64 -> 60 -> 56 -> 52, and the final
    52x52 convolution collapses the map to 1x1, so inputs must be 64x64.
    """

    def __init__(self):
        super(discriminator, self).__init__()
        self.conv_block1 = DisConvBlock(4, 256, 5, 0)
        self.conv_block2 = DisConvBlock(256, 512, 5, 0)
        self.conv_block3 = DisConvBlock(512, 1024, 5, 0)
        self.attn = SelfAttention(1024)
        # Must consume the 1024 channels produced by conv_block3 / attn
        # (it was declared with 64 input channels, which could never match).
        self.conv1 = spectral_norm(nn.Conv2d(in_channels=1024, out_channels=1,
                                             kernel_size=52))
        self.act = nn.Sigmoid()

    def forward(self, x):
        # The former call to self.conv_block4 was removed: no such layer is defined.
        x = self.conv_block1(x)
        x = self.conv_block2(x)
        x = self.conv_block3(x)
        x = self.attn(x)
        x = self.conv1(x)
        x = x.view(-1, 1)
        x = self.act(x)
        return x


class generator(nn.Module):
    """Generator mapping a (batch, 1, H, W) noise map to a (batch, 4, H, W) image in [-1, 1].

    Each layer preserves spatial size (5x5 transposed convs with padding 2, then
    a 3x3 with padding 1), so 64x64 noise produces a 64x64 image.
    """

    def __init__(self):
        super(generator, self).__init__()
        self.deconv_block1 = GenConvBlock(1, 1024, 5, 2)
        self.deconv_block2 = GenConvBlock(1024, 512, 5, 2)
        self.deconv_block3 = GenConvBlock(512, 256, 5, 2)
        self.attn = SelfAttention(256)
        self.deconv1 = nn.ConvTranspose2d(256, 4, 3, padding=1)
        self.act4 = nn.Tanh()

    def forward(self, x):
        x = self.deconv_block1(x)
        x = self.deconv_block2(x)
        x = self.deconv_block3(x)
        x = self.attn(x)
        x = self.deconv1(x)
        x = self.act4(x)
        return x

    def gen_skin(self, filename):
        """Sample one image from uniform noise and save it to ``filename``.

        The module's current train/eval mode is left unchanged.
        """
        # Put the noise wherever the model lives instead of relying on a global device.
        model_device = next(self.parameters()).device
        noise = torch.rand(1, 1, 64, 64, device=model_device)
        with torch.no_grad():  # sampling only; no autograd graph needed
            image = self(noise)
        # Map tanh output [-1, 1] to [0, 1]. save_image multiplies by 255 itself,
        # so the previous extra "* 255" saturated almost every pixel.
        save_image(image / 2 + 0.5, filename)

In [None]:
# Ad-hoc export of the discriminator losses. `trainer` is only created in a
# later cell, so on a fresh kernel (Restart & Run All) this cell used to raise
# NameError; it is now skipped until a trainer exists.
# Trainer.save_losses_to_file() writes both the discriminator and generator loss files.
if 'trainer' in globals():
    with open('/content/drive/MyDrive/SkinsGenerator/dis_losses.txt', 'a+') as f:
        f.writelines(str(number) + "\n" for number in trainer.dis_losses)

In [None]:
class Trainer:
    """Owns the GAN models and their optimizers and runs adversarial training.

    Checkpoints, loss logs and sample images are written under
    /content/drive/MyDrive/SkinsGenerator.
    """

    def __init__(self):
        self.discriminator = discriminator().to(device)
        self.generator = generator().to(device)
        self.D_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999))
        self.G_optimizer = torch.optim.Adam(self.generator.parameters(), lr=0.0001, betas=(0.5, 0.999))
        self.loss = nn.BCELoss()
        self.dis_losses = []    # per batch: real_loss + fake_loss of the discriminator
        self.gen_losses = []    # per batch: generator loss
        self.current_epoch = 0  # number of completed epochs; train() continues from here
        self._losses_saved = 0  # how many loss entries have already been appended to disk

    def load_model(self):
        """Restore both models, their optimizers and the epoch counter from Drive checkpoints."""
        # map_location lets checkpoints saved on a GPU runtime load on a CPU-only one.
        G_checkpoint = torch.load("/content/drive/MyDrive/SkinsGenerator/Generators/Generator.pth",
                                  map_location=device)
        D_checkpoint = torch.load("/content/drive/MyDrive/SkinsGenerator/Discriminators/Discriminator.pth",
                                  map_location=device)
        self.discriminator.load_state_dict(D_checkpoint['model_state_dict'])
        self.D_optimizer.load_state_dict(D_checkpoint['optimizer_state_dict'])
        self.generator.load_state_dict(G_checkpoint['model_state_dict'])
        self.G_optimizer.load_state_dict(G_checkpoint['optimizer_state_dict'])
        self.current_epoch = D_checkpoint['epoch']

    def save_losses_to_file(self):
        """Append losses recorded since the previous call to the two loss files, one value per line.

        Only new entries are written, so calling this after several train() runs
        does not duplicate earlier losses in the files.
        """
        start = self._losses_saved
        targets = (
            ('/content/drive/MyDrive/SkinsGenerator/dis_losses.txt', self.dis_losses),
            ('/content/drive/MyDrive/SkinsGenerator/gen_losses.txt', self.gen_losses),
        )
        for path, losses in targets:
            with open(path, 'a+') as f:
                f.writelines(str(number) + "\n" for number in losses[start:])
        self._losses_saved = len(self.dis_losses)

    def save_control_point(self, epoch=None):
        """Save generator/discriminator weights and optimizer states to Drive.

        Args:
            epoch: epoch number stored in both checkpoints. Defaults to
                ``self.current_epoch`` (completed epochs), which ``load_model``
                restores so training resumes at the next epoch.
        """
        # Previously this read an undefined name `epoch` (NameError unless a
        # stray global of that name happened to exist).
        if epoch is None:
            epoch = self.current_epoch
        torch.save({'epoch': epoch,
                    'model_state_dict': self.generator.state_dict(),
                    'optimizer_state_dict': self.G_optimizer.state_dict()},
                   '/content/drive/MyDrive/SkinsGenerator/Generators/Generator.pth')
        torch.save({'epoch': epoch,
                    'model_state_dict': self.discriminator.state_dict(),
                    'optimizer_state_dict': self.D_optimizer.state_dict()},
                   '/content/drive/MyDrive/SkinsGenerator/Discriminators/Discriminator.pth')
        print('Model saved.')

    def _set_discriminator_trainable(self, trainable):
        """Toggle requires_grad on every discriminator parameter."""
        for p in self.discriminator.parameters():
            p.requires_grad = trainable

    def _train_step(self, imgs, train_discriminator):
        """Run one batch: optionally update D, always update G.

        Returns:
            (real_loss + fake_loss, generator_loss) as Python floats.
        """
        n = imgs.shape[0]
        real_imgs = imgs.to(device)
        # One-sided label smoothing: "real" targets are drawn uniformly from [0.8, 1.0).
        valid = ((torch.rand(n, 1) + 4) / 5).to(device)
        fake = torch.zeros(n, 1, device=device)

        # ---- Discriminator ----
        # The sample is only used detached, so no generator graph is needed.
        with torch.no_grad():
            gen_imgs = self.generator(torch.rand(n, 1, 64, 64, device=device))
        self.D_optimizer.zero_grad()
        if train_discriminator:
            self._set_discriminator_trainable(True)
            real_loss = self.loss(self.discriminator(real_imgs), valid)
            real_loss.backward()
            fake_loss = self.loss(self.discriminator(gen_imgs), fake)
            fake_loss.backward()
            self.D_optimizer.step()
        else:
            # D is not updated on this batch; the losses are computed for logging only.
            with torch.no_grad():
                real_loss = self.loss(self.discriminator(real_imgs), valid)
                fake_loss = self.loss(self.discriminator(gen_imgs), fake)

        # ---- Generator ----
        # Freeze D so g_loss.backward() only produces generator gradients.
        self._set_discriminator_trainable(False)
        self.G_optimizer.zero_grad()
        gen_imgs = self.generator(torch.rand(n, 1, 64, 64, device=device))
        g_loss = self.loss(self.discriminator(gen_imgs), valid)
        g_loss.backward()
        self.G_optimizer.step()

        return float(real_loss + fake_loss), float(g_loss)

    def train(self, data_loader, crit, numbers_of_epoch, save_frequency):
        """Run adversarial training for ``numbers_of_epoch`` epochs.

        Args:
            data_loader: iterable of (images, labels) batches; only the images are used.
            crit: the discriminator is updated on every ``crit``-th batch
                (1 = every batch); the generator is updated on every batch.
            numbers_of_epoch: epochs to run, continuing from ``self.current_epoch``.
            save_frequency: every ``save_frequency`` epochs, save checkpoints and
                9 sample images.

        Losses are appended to disk when training ends, including when it is
        interrupted (e.g. by KeyboardInterrupt).
        """
        torch.set_grad_enabled(True)
        start_epoch = self.current_epoch
        end_epoch = start_epoch + numbers_of_epoch
        try:
            for epoch in range(start_epoch, end_epoch):
                progress_bar = tqdm(enumerate(data_loader), total=len(data_loader))
                for i, data in progress_bar:
                    d_loss, g_loss = self._train_step(data[0], (i + 1) % crit == 0)
                    progress_bar.set_description(f"[{epoch + 1}/{end_epoch}][{i + 1}/{len(data_loader)}] "
                                                 f"Loss_D: {d_loss / 2:.6f} Loss_G: {g_loss:.6f} ")
                    self.dis_losses.append(d_loss)
                    self.gen_losses.append(g_loss)
                # Advance before checkpointing so the saved epoch is the number completed.
                self.current_epoch = epoch + 1
                if (epoch + 1) % save_frequency == 0:
                    self.save_control_point()
                    for k in range(9):
                        self.generator.gen_skin(f"/content/drive/MyDrive/SkinsGenerator/Generated_imgs/epoch_{epoch+1}_{k}.png")
        finally:
            # Leave D trainable for any later use, and persist whatever was logged.
            self._set_discriminator_trainable(True)
            self.save_losses_to_file()

In [None]:
# Fresh models and optimizers; call trainer.load_model() afterwards to resume from the Drive checkpoints.
trainer = Trainer()

In [None]:
# Update D on every 5th batch, run 200 epochs, checkpoint and sample every 2 epochs.
trainer.train(dataloader, crit=5, numbers_of_epoch=200, save_frequency=2)

[1/200][415/415] Loss_D: 0.216459 Loss_G: 4.441345 : 100%|██████████| 415/415 [03:14<00:00,  2.13it/s]
[2/200][415/415] Loss_D: 0.199475 Loss_G: 6.055211 : 100%|██████████| 415/415 [03:08<00:00,  2.20it/s]
[3/200][415/415] Loss_D: 0.179310 Loss_G: 5.092997 : 100%|██████████| 415/415 [03:09<00:00,  2.19it/s]
[4/200][415/415] Loss_D: 0.204162 Loss_G: 4.615445 : 100%|██████████| 415/415 [03:09<00:00,  2.19it/s]
[5/200][415/415] Loss_D: 0.181685 Loss_G: 5.016153 : 100%|██████████| 415/415 [03:09<00:00,  2.19it/s]
[6/200][144/415] Loss_D: 0.194663 Loss_G: 5.091354 :  35%|███▍      | 144/415 [01:06<02:04,  2.17it/s]


KeyboardInterrupt: ignored