## Imports

In [None]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader

from torchinfo import summary

import torchvision
import torchvision.transforms as transforms
from torchvision.utils import save_image

import numpy as np

import matplotlib.pyplot as plt

import random
from pathlib import Path
from PIL import Image

from tqdm.notebook import tqdm, trange

## Settings

In [None]:
# Model variant label: keeps runs with different hyperparameters apart.
# Together with SEED it names the output directory result/<MODEL_NAME>/<VARIANT><SEED>/.
VARIANT = "A"
# Seed passed to the torch, numpy and random generators so that runs are repeatable.
SEED = 451

## Global variables

In [None]:
# Name of this model family: used as a sub-directory of SAVE_DIR and as the loss-plot title.
MODEL_NAME = "selfattn_residual_dcgan"

# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device(
    "cuda" if torch.cuda.is_available() else "cpu"
)
# DEVICE = torch.device("cpu")
# Tensors and modules created without an explicit device are placed on DEVICE from here on.
torch.set_default_device(DEVICE)

# Number of channels of the (B, LATENT_DIM, 1, 1) noise tensor fed to the generator.
LATENT_DIM = 100
# Samples per training batch.
BATCH_SIZE = 256
# Number of passes over the training set.
EPOCHS = 100
# Adam learning rates for the generator and the discriminator.
LEARNING_RATE_G = 0.0002
LEARNING_RATE_D = 0.0002
# Adam beta coefficients, shared by both optimizers.
BETA_1 = 0.5
BETA_2 = 0.999
# Dropout2d probability used inside the discriminator.
DROPOUT_P = 0.2

# Output and dataset locations
SAVE_DIR = "./result"
DATA_DIR = "./data"

# Everything produced by this run (log, weights, loss plot) goes under RESULT_DIR;
# the per-epoch sample grids go under GIF_DIR.
RESULT_DIR = f"{SAVE_DIR}/{MODEL_NAME}/{VARIANT}{SEED}"
GIF_DIR = f"{RESULT_DIR}/gif"

# Create the output directories up front; an already existing directory is not an error.
Path(RESULT_DIR).mkdir(parents=True, exist_ok=True)
Path(GIF_DIR).mkdir(parents=True, exist_ok=True)

In [None]:
# Reproducibility setup.
# Compare the device *type* rather than its string form: str(torch.device("cuda:0"))
# is "cuda:0", which would silently skip the cuDNN settings below.
if DEVICE.type == 'cuda':
    # Disable the cuDNN autotuner and request deterministic cuDNN algorithms.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
# Seed every random number generator the notebook uses.
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

## Utils

In [None]:
def write_log(msg, log=f"{RESULT_DIR}/hist.log"):
    """Append ``msg`` followed by a newline to the text file ``log``.

    Args:
        msg: text to append (a ``str``).
        log: path of the log file; defaults to ``hist.log`` inside RESULT_DIR
            (the default is evaluated once, when this function is defined).
    """
    # Write UTF-8 explicitly instead of the platform's locale encoding: the model
    # summaries logged by this notebook can contain non-ASCII characters, which
    # some locale encodings cannot represent (UnicodeEncodeError).
    with open(log, "a", encoding="utf-8") as f:
        f.write(msg + "\n")

In [None]:
def save_models():
    """Save the generator and discriminator state dicts as .pth files in RESULT_DIR."""
    # Destination path -> network whose weights are written there.
    targets = {
        f"{RESULT_DIR}/generator.pth": generator,
        f"{RESULT_DIR}/discriminator.pth": discriminator,
    }
    for path, model in targets.items():
        torch.save(model.state_dict(), path)

In [None]:
def plot_gan_losses(generator_losses, discriminator_losses, model_name="GAN", save_path=None):
    """Plot the generator and discriminator loss curves on a single figure.

    Args:
        generator_losses: sequence of generator loss values, one per epoch.
        discriminator_losses: sequence of discriminator loss values, one per epoch.
        model_name: text placed in front of "losses" in the figure title.
        save_path: directory in which the figure is saved under the name
            ``losses`` (no file extension is given here); nothing is saved
            when this is empty or ``None``.
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    curves = (
        (generator_losses, 'Generator loss', 'red'),
        (discriminator_losses, 'Discriminator loss', 'blue'),
    )
    for values, label, color in curves:
        ax.plot(values, label=label, color=color)

    ax.set_title(f'{model_name} losses')
    ax.set_xlabel('Epochs')
    ax.set_ylabel('Losses')
    ax.legend()
    ax.grid()

    if save_path:
        fig.savefig(f"{save_path}/losses", bbox_inches='tight')
        print(f"Save path: {save_path}")

    plt.show()

## Info

In [None]:
print(f"device: {DEVICE}")

# Record every setting of this run in the log, between two separator lines, so a
# run can be reconstructed from hist.log alone. The variant label and the batch
# size are included because they are run settings as well.
write_log("="*16)
write_log(f"variant: {VARIANT}")
write_log(f"device: {DEVICE}")
write_log(f"latent_dim: {LATENT_DIM}")
write_log(f"batch_size: {BATCH_SIZE}")
write_log(f"epochs: {EPOCHS}")
write_log(f"learning_rate_G: {LEARNING_RATE_G}")
write_log(f"learning_rate_D: {LEARNING_RATE_D}")
write_log(f"dropout_p: {DROPOUT_P}")
write_log(f"betas: ({BETA_1},{BETA_2})")
write_log(f"seed: {SEED}")
write_log("="*16)

## Dataset

In [None]:
def get_mnist_dataloader(batch_size=None, data_dir='./data'):
    """Build a shuffled DataLoader over the MNIST training split.

    Each image is converted to a tensor and normalised with mean 0.5 and
    std 0.5, which matches the Tanh output range of the generator.

    Args:
        batch_size: samples per batch. ``None`` (the default) falls back to the
            notebook-level ``BATCH_SIZE``.
        data_dir: directory where MNIST is stored (downloaded if missing).

    Returns:
        A DataLoader yielding ``(images, labels)`` batches; the last
        incomplete batch is dropped.
    """
    if batch_size is None:
        # DataLoader raises ValueError for batch_size=None combined with
        # drop_last=True, so the old default could never work as written.
        batch_size = BATCH_SIZE

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    train_dataset = torchvision.datasets.MNIST(
        root=data_dir,
        train=True,
        download=True,
        transform=transform
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        # NOTE(review): the shuffling generator is created on DEVICE, presumably
        # because the notebook makes DEVICE the default device and the sampler's
        # random permutation must live on the same device as its generator.
        generator=torch.Generator(device=DEVICE),
        drop_last=True
    )

    return train_loader


In [None]:
# Shuffled MNIST training loader consumed by the training loop.
dataloader = get_mnist_dataloader(batch_size=BATCH_SIZE, data_dir=DATA_DIR)

## Models

In [None]:
class SelfAttention(nn.Module):
    """Self-attention over the spatial positions of a feature map.

    Three 1x1 convolutions project the input: ``Wf`` and ``Wg`` to ``ch // 2``
    channels, ``Wh`` to ``ch`` channels. The attention result is scaled by the
    learnable ``gamma`` (initialised to zero) and added back to the input, so
    the layer starts out as an identity mapping.
    """

    def __init__(self, ch):
        super().__init__()
        self.Wf = nn.Conv2d(ch, ch // 2, kernel_size=1)
        self.Wg = nn.Conv2d(ch, ch // 2, kernel_size=1)
        self.Wh = nn.Conv2d(ch, ch, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1, 1))

    def forward(self, x):
        """Return ``gamma * attention(x) + x`` with the same (B, C, H, W) shape as ``x``."""
        batch, channels, height, width = x.shape
        positions = height * width  # N: number of spatial locations

        f_proj = self.Wf(x).view(batch, -1, positions).transpose(1, 2)  # (B, N, C')
        g_proj = self.Wg(x).view(batch, -1, positions)                  # (B, C', N)
        h_proj = self.Wh(x).view(batch, -1, positions)                  # (B, C, N)

        # Pairwise scores between positions, normalised over the last axis.
        scores = torch.matmul(f_proj, g_proj)                           # (B, N, N)
        weights = scores.softmax(dim=-1)                                # (B, N, N)
        attended = torch.matmul(h_proj, weights.transpose(1, 2))        # (B, C, N)

        skip = x.view(batch, channels, positions)
        return (self.gamma * attended + skip).view(batch, channels, height, width)

class ResidualBlock(nn.Module):
    """Two 3x3 conv -> BatchNorm -> Dropout2d stages with a skip connection.

    A LeakyReLU(0.2) sits between the two stages; no activation is applied
    after the skip addition. The channel count and spatial size are preserved.

    Args:
        ch: number of input (and output) channels.
        dropout_p: probability given to both Dropout2d layers.
    """

    def __init__(self, ch, dropout_p=0.0):
        super().__init__()

        def conv_bn_drop():
            # One stage: bias-free 3x3 convolution, batch norm, channel dropout.
            return [
                nn.Conv2d(ch, ch, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(ch),
                nn.Dropout2d(dropout_p),
            ]

        self.fblock = nn.Sequential(
            *conv_bn_drop(),
            nn.LeakyReLU(0.2),
            *conv_bn_drop(),
        )

    def forward(self, x):
        """Return ``fblock(x) + x``."""
        return x + self.fblock(x)

n = 32  # base channel width used by both the generator and the discriminator
class BasicGenerator(nn.Module):
    """DCGAN-style generator with residual blocks and one self-attention layer.

    Maps a latent tensor of shape (B, latent_dim, 1, 1) to a single-channel
    image with Tanh-bounded values. Spatial sizes noted below assume that
    1x1 latent input.

    Args:
        latent_dim: number of channels of the latent input.
    """

    def __init__(self, latent_dim=100):
        super().__init__()

        # 1x1 -> 7x7
        stem = [
            nn.ConvTranspose2d(latent_dim, n*4, kernel_size=7, stride=1, padding=0, bias=False),
            nn.ReLU(),
        ]
        # 7x7 -> 14x14, then residual refinement and self-attention
        up_14 = [
            nn.ConvTranspose2d(n*4, n*2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(n*2),
            nn.ReLU(),
            ResidualBlock(n*2),
            nn.ReLU(),
            SelfAttention(n*2),
        ]
        # 14x14 -> 28x28, then residual refinement
        up_28 = [
            nn.ConvTranspose2d(n*2, n, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(n),
            nn.ReLU(),
            ResidualBlock(n),
            nn.ReLU(),
        ]
        # Project to one image channel and squash with Tanh
        head = [
            nn.ConvTranspose2d(n, 1, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),
        ]

        # The flat layer order is what defines the state_dict keys (model.0, model.1, ...).
        self.model = nn.Sequential(*stem, *up_14, *up_28, *head)

    def forward(self, x):
        """Run the latent tensor ``x`` through the layer stack."""
        return self.model(x)

class BasicDiscriminator(nn.Module):
    """Convolutional discriminator with residual blocks and channel dropout.

    Takes single-channel images and returns one raw logit per image (the last
    layer is a plain convolution, no sigmoid). Spatial sizes noted below
    assume 28x28 inputs. Every Dropout2d uses the notebook-level DROPOUT_P.
    """

    def __init__(self):
        super().__init__()

        # 28x28, 1 -> n channels
        stem = [
            nn.Conv2d(1, n, kernel_size=3, stride=1, padding=1, bias=False),
            nn.Dropout2d(DROPOUT_P),
            nn.LeakyReLU(0.2),
        ]
        # 28x28 -> 14x14, then residual refinement
        down_14 = [
            nn.Conv2d(n, n*2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(n*2),
            nn.Dropout2d(DROPOUT_P),
            nn.LeakyReLU(0.2),
            ResidualBlock(n*2, dropout_p=DROPOUT_P),
            nn.LeakyReLU(0.2),
        ]
        # 14x14 -> 7x7, then residual refinement
        down_7 = [
            nn.Conv2d(n*2, n*4, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(n*4),
            nn.Dropout2d(DROPOUT_P),
            nn.LeakyReLU(0.2),
            ResidualBlock(n*4, dropout_p=DROPOUT_P),
            nn.LeakyReLU(0.2),
        ]
        # 7x7 -> 1x1 logit
        head = [
            nn.Conv2d(n*4, 1, kernel_size=7, stride=1, padding=0),
        ]

        # The flat layer order is what defines the state_dict keys (model.0, model.1, ...).
        self.model = nn.Sequential(*stem, *down_14, *down_7, *head)

    def forward(self, x):
        """Return the raw logits for the image batch ``x``."""
        return self.model(x)

In [None]:
# Instantiate the generator and build its layer summary for one latent
# tensor of shape (1, LATENT_DIM, 1, 1); the last line displays the summary.
generator = BasicGenerator(latent_dim=LATENT_DIM)
sm_g = summary(generator, input_size=(1, LATENT_DIM, 1, 1), device=DEVICE)
sm_g

In [None]:
# Instantiate the discriminator and build its layer summary for one
# 1x28x28 image; the last line displays the summary.
discriminator = BasicDiscriminator()
sm_d = summary(discriminator, input_size=(1, 1, 28, 28), device=DEVICE)
sm_d

In [None]:
# Append the text form of both model summaries to the run log.
write_log(str(sm_g))
write_log(str(sm_d))

In [None]:
# Optimizers: one Adam instance per network, each with its own learning rate
# and the shared (BETA_1, BETA_2) coefficients.
g_optimizer = optim.Adam(generator.parameters(), lr=LEARNING_RATE_G, betas=(BETA_1, BETA_2))
d_optimizer = optim.Adam(discriminator.parameters(), lr=LEARNING_RATE_D, betas=(BETA_1, BETA_2))

## Training step

In [None]:
# Binary cross-entropy on raw logits: the discriminator ends in a plain
# convolution, so the sigmoid is applied inside the loss.
criterion = nn.BCEWithLogitsLoss()


In [None]:
def d_step(real_images, batch_size):
    """Run one discriminator update on a real batch and a freshly generated fake batch.

    The discriminator is trained to score generated images as 0 and real
    images as 0.9 (one-sided label smoothing). The generator only produces the
    fakes here; no gradient flows into it.

    Args:
        real_images: batch of real images fed to the discriminator.
        batch_size: number of fake images to generate and the size of both
            label tensors.

    Returns:
        The discriminator loss (fake part + real part) as a Python float.
    """
    noise = torch.randn(batch_size, LATENT_DIM, 1, 1)

    fake_label = torch.zeros(batch_size, 1, 1, 1)
    real_label = torch.full((batch_size, 1, 1, 1), 0.9)

    # The notebook's per-epoch preview switches the generator to eval mode.
    # Without this call the next discriminator step would sample fakes with
    # eval-mode BatchNorm statistics, unlike every other step of the epoch.
    generator.train()
    with torch.no_grad():
        # Produced under no_grad, so the result is already detached from the generator.
        fake_images = generator(noise)

    discriminator.train()
    d_optimizer.zero_grad()

    fake_pred = discriminator(fake_images)
    d_loss_fake = criterion(fake_pred, fake_label)

    real_pred = discriminator(real_images)
    d_loss_real = criterion(real_pred, real_label)

    d_loss = d_loss_fake + d_loss_real

    d_loss.backward()
    d_optimizer.step()

    return d_loss.item()

def g_step(batch_size):
    """Run one generator update.

    A fresh noise batch is turned into images, the discriminator scores them,
    and the generator is optimised so that those scores match the "real"
    target of 1. Only the generator's optimizer takes a step.

    Args:
        batch_size: number of images to generate.

    Returns:
        The generator loss as a Python float.
    """
    generator.train()
    g_optimizer.zero_grad()

    latent = torch.randn(batch_size, LATENT_DIM, 1, 1)
    target = torch.ones(batch_size, 1, 1, 1)

    # Gradients flow through the discriminator into the generator.
    loss = criterion(discriminator(generator(latent)), target)
    loss.backward()
    g_optimizer.step()

    return loss.item()

## Training

In [None]:
# Mean discriminator / generator loss of every finished epoch.
history_losses_d = []
history_losses_g = []

# Fixed latent batch reused for every preview, so the saved grids are
# comparable from epoch to epoch.
fixed_noise = torch.randn(64, LATENT_DIM, 1, 1).to(DEVICE)

for epoch in trange(EPOCHS, unit="epoch"):
    # The preview at the end of the previous epoch left the generator in eval
    # mode; put both networks back into training mode before any update.
    generator.train()
    discriminator.train()

    num_batches = 0
    epoch_d_loss = 0
    epoch_g_loss = 0

    tq = tqdm(
        dataloader,
        total=len(dataloader),
        leave=False,
        unit="batch",)

    # One discriminator step followed by one generator step per batch.
    for real_images, _ in tq:
        real_images = real_images.to(DEVICE)
        batch_size = real_images.size(0)
        epoch_d_loss += d_step(real_images, batch_size)
        epoch_g_loss += g_step(batch_size)
        num_batches += 1

    history_losses_d.append(epoch_d_loss / num_batches)
    history_losses_g.append(epoch_g_loss / (num_batches))

    info = f'epoch [{(epoch+1):>3}/{EPOCHS}], ' + \
         f'g_loss: {history_losses_g[-1]:.5f}, ' + \
         f'd_loss: {history_losses_d[-1]:.5f}'

    print(info)

    write_log(info)

    # Preview: render the fixed noise with the generator in eval mode.
    generator.eval()
    with torch.no_grad():
        gen = generator(fixed_noise)
        # The generator ends in Tanh, so pixel values lie in [-1, 1]. Rescale
        # them to [0, 1] when saving; without this every negative value is
        # clipped to black in the written image.
        save_image(gen.view(gen.size(0), 1, 28, 28),
                   f"{GIF_DIR}/{epoch+1}.png",
                   normalize=True,
                   value_range=(-1, 1))

## Save model

In [None]:
# Write the trained generator and discriminator weights into RESULT_DIR.
save_models()

## History

In [None]:
# Plot the per-epoch mean losses and save the figure into RESULT_DIR.
plot_gan_losses(history_losses_g, history_losses_d, MODEL_NAME, RESULT_DIR)