In [None]:
# Основные библиотеки
import numpy as np 
from numpy.random import random
from scipy.linalg import sqrtm
import pandas as pd
import os
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import torchvision.transforms as tt
import torch
from torch import device
import torch.nn as nn
import cv2
from tqdm.notebook import tqdm
import torch.nn.functional as F
from torchvision.utils import save_image
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Select the compute device: the CUDA GPU when one is available, otherwise the CPU.
# Note: this rebinds the name `device`, which the import cell also brings in
# via `from torch import device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# The dataset lives on Google Drive: mount the Drive into the Colab filesystem
from google.colab import drive
drive.mount('/content/gdrive/')

Drive already mounted at /content/gdrive/; to attempt to forcibly remount, call drive.mount("/content/gdrive/", force_remount=True).


In [None]:
# List the contents of the mounted Drive root
!ls /content/gdrive/MyDrive

 archive.zip   cats  'Colab Notebooks'	 nauk.rar   Timetable.gsheet


In [None]:
# Root folder handed to ImageFolder below; the image files themselves sit in
# its 'cats' subfolder.
direc = '/content/gdrive/MyDrive/cats'

In [None]:
# Show the first 10 file names found in the image subfolder
print(os.listdir(direc + '/cats')[:10])


['94.jpg', '9598.jpg', '9661.jpg', '9867.jpg', '9352.jpg', '9513.jpg', '9973.jpg', '9116.jpg', '9110.jpg', '8990.jpg']


In [None]:
# Hyperparameters of the data pipeline and of the generator input.
image_size = 64     # images are resized and centre-cropped to image_size x image_size
batch_size = 64     # number of images per mini-batch
latent_size = 128   # number of channels of the generator's input noise
# (mean, std) per channel for Normalize; with 0.5/0.5 the ToTensor range
# [0, 1] is mapped to [-1, 1], matching the generator's Tanh output.
stats = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)

# Image dataset read from `direc`, with the preprocessing applied on load.
train = ImageFolder(
    direc,
    transform=tt.Compose([
        tt.Resize(image_size),
        tt.CenterCrop(image_size),
        tt.ToTensor(),
        tt.Normalize(*stats),
    ]),
)

# Shuffled mini-batch loader over the dataset.
train_dl = DataLoader(
    train,
    batch_size=batch_size,
    shuffle=True,
    num_workers=3,
    pin_memory=True,
)

In [None]:
# Generator
def _gen_block(in_channels, out_channels, stride, padding):
    """Build one upsampling stage: ConvTranspose2d (4x4, no bias) -> BatchNorm2d -> ReLU."""
    return [
        nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4,
                           stride=stride, padding=padding, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(True),
    ]


# The stages are unpacked into one flat Sequential, so the layer order and
# indices are the same as when every layer is listed explicitly.
generator = nn.Sequential(
    # in: latent_size x 1 x 1
    *_gen_block(latent_size, 512, stride=1, padding=0),  # out: 512 x 4 x 4
    *_gen_block(512, 256, stride=2, padding=1),          # out: 256 x 8 x 8
    *_gen_block(256, 128, stride=2, padding=1),          # out: 128 x 16 x 16
    *_gen_block(128, 64, stride=2, padding=1),           # out: 64 x 32 x 32

    # Final stage has no BatchNorm; Tanh bounds the pixels to [-1, 1].
    nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1, bias=False),
    nn.Tanh(),
    # out: 3 x 64 x 64
)


In [None]:
# Discriminator
def _disc_block(in_channels, out_channels):
    """Build one downsampling stage: Conv2d (4x4, stride 2, no bias) -> BatchNorm2d -> LeakyReLU(0.2)."""
    return [
        nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.LeakyReLU(0.2, inplace=True),
    ]


# The stages are unpacked into one flat Sequential, so the layer order and
# indices are the same as when every layer is listed explicitly.
discriminator = nn.Sequential(
    # in: 3 x 64 x 64
    *_disc_block(3, 64),      # out: 64 x 32 x 32
    *_disc_block(64, 128),    # out: 128 x 16 x 16
    *_disc_block(128, 256),   # out: 256 x 8 x 8
    *_disc_block(256, 512),   # out: 512 x 4 x 4

    # Collapse the 4 x 4 feature map to a single value per image.
    nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0, bias=False),
    # out: 1 x 1 x 1

    # Flatten to (batch, 1) and squash to a value in (0, 1).
    nn.Flatten(),
    nn.Sigmoid(),
)

In [None]:
def denorm(img_tensors):
    """Undo the Normalize(*stats) scaling of an image tensor.

    Only the first channel's mean and std from `stats` are used and applied
    to every channel.
    """
    mean, std = stats[0][0], stats[1][0]
    return img_tensors * std + mean

In [None]:
# Saving sample images during training (one grid per epoch)

sample_dir = 'generated'
os.makedirs(sample_dir, exist_ok=True)

def save_samples(index, latent_tensors, show=True):
    """Generate images from `latent_tensors` and save them as one PNG grid.

    Args:
        index: integer used in the file name (zero-padded to 4 digits).
        latent_tensors: generator input noise; expected on the generator's device.
        show: when True, also display the grid with matplotlib.

    The file is written to `sample_dir` as 'generated-images-XXXX.png'.
    """
    # Sampling is inference only: skip building an autograd graph.
    with torch.no_grad():
        # The generator ends in Tanh ([-1, 1]); map back to [0, 1] once so the
        # saved file and the on-screen preview show the same image.
        fake_images = denorm(generator(latent_tensors))
    fake_fname = 'generated-images-{0:0=4d}.png'.format(index)
    save_image(fake_images, os.path.join(sample_dir, fake_fname), nrow=8)
    print('Saving', fake_fname)
    if show:
        fig, ax = plt.subplots(figsize=(8, 8))
        ax.set_xticks([]); ax.set_yticks([])
        # make_grid returns (C, H, W); imshow wants (H, W, C).
        ax.imshow(make_grid(fake_images.cpu(), nrow=8).permute(1, 2, 0))

In [None]:
# Discriminator training step
def train_discriminator(real_images, opt_d):
    """Run one discriminator update on a real batch and a freshly generated fake batch.

    Args:
        real_images: batch of real images, expected on the discriminator's device.
        opt_d: optimizer holding the discriminator's parameters.

    Returns:
        (loss, real_score, fake_score) as Python floats: the summed real + fake
        binary cross-entropy, and the mean discriminator output on the real
        and on the fake batch.
    """
    # Reset gradients left over from the previous step.
    opt_d.zero_grad()

    # Real images should be classified as 1.
    real_preds = discriminator(real_images)
    real_targets = torch.ones(real_images.size(0), 1, device=device)
    real_loss = F.binary_cross_entropy(real_preds, real_targets)
    real_score = torch.mean(real_preds).item()

    # Generate fakes without tracking gradients: only the discriminator is
    # updated here, so back-propagating through the generator would be
    # wasted work and would fill generator gradients nobody uses.
    with torch.no_grad():
        latent = torch.randn(batch_size, latent_size, 1, 1, device=device)
        fake_images = generator(latent)

    # Fake images should be classified as 0.
    fake_targets = torch.zeros(fake_images.size(0), 1, device=device)
    fake_preds = discriminator(fake_images)
    fake_loss = F.binary_cross_entropy(fake_preds, fake_targets)
    fake_score = torch.mean(fake_preds).item()

    # Update the discriminator weights.
    loss = real_loss + fake_loss
    loss.backward()
    opt_d.step()
    return loss.item(), real_score, fake_score

In [None]:
# Generator training step
def train_generator(opt_g):
    """Run one generator update that tries to make the discriminator output 1 on fakes.

    Args:
        opt_g: optimizer holding the generator's parameters.

    Returns:
        (loss, latent): the binary cross-entropy as a Python float and the
        noise batch that was used for this step.

    The backward pass also fills gradients on the discriminator, but only
    `opt_g` takes a step here.
    """
    # Reset gradients left over from the previous step.
    opt_g.zero_grad()

    # Sample the noise directly on the target device and generate fakes.
    latent = torch.randn(batch_size, latent_size, 1, 1, device=device)
    fake_images = generator(latent)

    # "Fool" the discriminator: the target is 1, so the loss is small when
    # the discriminator takes the fakes for real images.
    preds = discriminator(fake_images)
    targets = torch.ones(batch_size, 1, device=device)
    loss = F.binary_cross_entropy(preds, targets)

    # Update the generator weights.
    loss.backward()
    opt_g.step()

    return loss.item(), latent

In [None]:
def fit(epochs, lr, start_idx=1):
    """Train the discriminator and generator alternately for `epochs` epochs.

    Args:
        epochs: number of passes over `train_dl`.
        lr: learning rate for both Adam optimizers.
        start_idx: index used in the sample file name of the first epoch.

    Returns:
        (losses_g, losses_d, latent, fake_scores): per-epoch generator losses,
        per-epoch discriminator losses, the noise batch of the last generator
        step, and per-epoch fake scores. Each per-epoch value is the one from
        the last batch of that epoch. Real scores are collected but are not
        part of the returned tuple.
    """
    torch.cuda.empty_cache()

    # Per-epoch history of losses and scores.
    history = {'loss_g': [], 'loss_d': [], 'real': [], 'fake': []}

    # Both optimizers share the same Adam settings; the models are moved to
    # `device` while their parameters are collected.
    adam_settings = dict(lr=lr, betas=(0.5, 0.999))
    opt_d = torch.optim.Adam(discriminator.to(device).parameters(), **adam_settings)
    opt_g = torch.optim.Adam(generator.to(device).parameters(), **adam_settings)

    for epoch_no in range(1, epochs + 1):
        for batch, _ in tqdm(train_dl):
            # One discriminator step followed by one generator step per batch.
            d_loss, d_real, d_fake = train_discriminator(batch.to(device), opt_d)
            g_loss, latent = train_generator(opt_g)

        # Record the values of the last batch of this epoch.
        history['loss_g'].append(g_loss)
        history['loss_d'].append(d_loss)
        history['real'].append(d_real)
        history['fake'].append(d_fake)

        print("Эпоха [{}/{}], loss_g: {}, loss_d: {}, real_score: {}, fake_score: {}".format(
            epoch_no, epochs, g_loss, d_loss, d_real, d_fake))

        # Save a sample grid generated from the last noise batch of the epoch.
        save_samples(epoch_no - 1 + start_idx, latent, show=False)

    return history['loss_g'], history['loss_d'], latent, history['fake']

In [None]:
# Train the model
# With larger numbers of epochs the network ended up in mode collapse
# Note: despite its name, `model` holds fit()'s return tuple
# (losses_g, losses_d, latent, fake_scores), not a network.
model = fit(epochs=30, lr=0.0002)

In [None]:
# Pack the 'generated' directory (per-epoch sample grids) into generated.zip
!zip -r ./generated.zip ./generated/

In [None]:
# Trigger a download of the archive through Colab's files helper
from google.colab import files
files.download("generated.zip")

In [None]:
# Saving the generated images (10000)

# Kept in its own variable: `save_samples` reads the global `sample_dir` at
# call time, so reassigning that name here would redirect the per-epoch
# grids of any later fit() run into this folder.
full_sample_dir = 'generated_full/cats'
os.makedirs(full_sample_dir, exist_ok=True)

def save_samples_full(index, latent_tensors):
    """Generate images from `latent_tensors` and save them as one PNG file.

    Args:
        index: integer used in the file name (zero-padded to 4 digits).
        latent_tensors: generator input noise; expected on the generator's device.

    The file is written to `full_sample_dir` as 'cats-images-XXXX.png'.
    """
    # Inference only: skip building an autograd graph for every image.
    with torch.no_grad():
        fake_images = generator(latent_tensors)
    fake_fname = 'cats-images-{0:0=4d}.png'.format(index)
    save_image(denorm(fake_images), os.path.join(full_sample_dir, fake_fname))
    print('Saving', fake_fname)

In [None]:
# Generate the fakes: 10000 files, each produced from its own freshly
# sampled noise batch of size 1
for i in range(10000):
    latent = torch.randn(1, latent_size, 1,1).to(device)
    save_samples_full(i, latent)

In [None]:
# Pack the generated_full/cats directory into generated_full.zip
!zip -r ./generated_full.zip ./generated_full/cats

In [None]:
# Trigger a download of the full archive through Colab's files helper
from google.colab import files
files.download("generated_full.zip")

In [None]:
# List the current working directory
!ls

In [None]:
# Clean-up: delete the per-epoch sample directory and its archive
!rm -rf generated
!rm -rf generated.zip

In [None]:
# Clean-up: delete the full set of generated images and its archive
!rm -rf generated_full
!rm -rf generated_full.zip

In [None]:
# Release GPU memory that PyTorch has cached but is no longer using
torch.cuda.empty_cache()