In [None]:
import torch
import torchvision
import ignite

# Report the library versions this notebook was run with, one "name: version" per line.
print("\n".join(f"{mod.__name__}: {mod.__version__}" for mod in (torch, torchvision, ignite)))

In [None]:
import os
import logging
import matplotlib.pyplot as plt

import cv2
import numpy as np

from torchsummary import summary

import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.utils import save_image
import torchvision.datasets as dset
import torchvision.utils as vutils

from ignite.engine import Engine, Events
import ignite.distributed as idist

In [None]:
# Fix the random seed so runs are reproducible.
ignite.utils.manual_seed(999)

# Only WARNING and above from these ignite components; their lower-level messages are suppressed.
for logger_name in (
    "ignite.distributed.auto.auto_dataloader",
    "ignite.distributed.launcher.Parallel",
):
    ignite.utils.setup_logger(name=logger_name, level=logging.WARNING)

In [None]:
image_size = 256

# Resize + center-crop to image_size x image_size, then scale each channel to [-1, 1]
# (mean 0.5 / std 0.5), matching the range of the generator's Tanh output.
data_transform = transforms.Compose(
    [
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        # transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
train_dataset = dset.ImageFolder(root="datasets/HighResolution/FLIR", transform=data_transform)

# The "test" set is the first (up to) 3000 training images; it is only used to compute
# FID / IS. Capping at len(train_dataset) avoids an IndexError during evaluation when the
# dataset is smaller than 3000 images.
n_eval_images = min(3000, len(train_dataset))
test_dataset = torch.utils.data.Subset(train_dataset, range(n_eval_images))

In [None]:
batch_size = 9

# Settings shared by both loaders; only shuffling differs. drop_last=True means the
# evaluation loader also skips its final incomplete batch.
loader_kwargs = dict(batch_size=batch_size, num_workers=8, drop_last=True)

train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
test_dataloader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

In [None]:
# Show the first 4 images of one training batch and print the shape of a single image.
real_batch = next(iter(train_dataloader))
preview_grid = vutils.make_grid(real_batch[0][:4], padding=2, normalize=True).cpu()

plt.figure(figsize=(20, 20))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(preview_grid, (1, 2, 0)))  # CHW -> HWC for matplotlib
plt.show()
print(real_batch[0][0].shape)

In [None]:
# Number of channels of the generator's input noise; noise is sampled throughout this
# notebook with shape (N, latent_dim, 1, 1).
latent_dim = 150

In [None]:
# class Generator1x512x512(nn.Module):
#     def __init__(self, latent_dim):
#         super(Generator1x512x512, self).__init__()
#         self.model = nn.Sequential(
#             nn.ConvTranspose2d(in_channels=latent_dim, out_channels=2096, 
#                                kernel_size=4, stride=1, padding=0, bias=False),
#             nn.BatchNorm2d(2096),
#             nn.ReLU(True),
#             # state size. 2096 x 4 x 4
#             nn.ConvTranspose2d(2096, 1024, 4, 2, 1, bias=False),
#             nn.BatchNorm2d(1024),
#             nn.ReLU(True),
#             # state size. 1024 x 8 x 8
#             nn.ConvTranspose2d(1024, 512, 4, 2, 1, bias=False),
#             nn.BatchNorm2d(512),
#             nn.ReLU(True),
#             # state size. 512 x 16 x 16
#             nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
#             nn.BatchNorm2d(256),
#             nn.ReLU(True),
#             # state size. 256 x 32 x 32
#             nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
#             nn.BatchNorm2d(128),
#             nn.ReLU(True),
#             # state size. 128 x 64 x 64

#             nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
#             nn.BatchNorm2d(64),
#             nn.ReLU(True),
#             # state size. 64 x 128 x 128

#             nn.ConvTranspose2d(64, 32, 4, 2, 1, bias=False),
#             nn.BatchNorm2d(32),
#             nn.ReLU(True),
#             # state size. 32 x 256 x 256

#             nn.ConvTranspose2d(32, 3, 4, 2, 1, bias=False),
#             nn.Tanh(),
#             # state size. 3 x 512 x 512
#         )

#     def forward(self, x):
#         x = self.model(x)
#         return x

In [None]:
class Generator1x512x512(nn.Module):
    """DCGAN-style generator that turns a noise tensor into an RGB image.

    Despite the "512x512" in the class name, this layer stack produces a
    3 x 256 x 256 image for a (N, latent_dim, 1, 1) input: a stride-1
    transposed convolution projects the noise to 1024 x 4 x 4, then six
    stride-2 transposed convolutions each double the spatial size
    (4 -> 8 -> 16 -> 32 -> 64 -> 128 -> 256). Output values are in [-1, 1] (Tanh).

    Args:
        latent_dim: number of channels of the input noise tensor.
    """

    def __init__(self, latent_dim):
        super().__init__()
        # (in_channels, out_channels) of the upsampling stages that are followed by
        # BatchNorm + SELU; spatial size after each: 8, 16, 32, 64, 128.
        hidden_stages = [(1024, 512), (512, 256), (256, 128), (128, 64), (64, 32)]

        # latent_dim x 1 x 1 -> 1024 x 4 x 4
        layers = [
            nn.ConvTranspose2d(in_channels=latent_dim, out_channels=1024,
                               kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(1024),
            nn.SELU(True),
        ]
        for c_in, c_out in hidden_stages:
            layers += [
                nn.ConvTranspose2d(c_in, c_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.SELU(True),
            ]
        # 32 x 128 x 128 -> 3 x 256 x 256, squashed to [-1, 1]
        layers += [
            nn.ConvTranspose2d(32, 3, 4, 2, 1, bias=False),
            nn.Tanh(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Map a noise batch through the layer stack and return the images."""
        return self.model(x)

In [None]:
# Instantiate the generator and move it to the training device.
netG = Generator1x512x512(latent_dim).to(device)

In [None]:
# torch.cuda.set_device(0)  # select GPU 0 as the current CUDA device
# Display the device resolved by ignite's distributed helpers. The FID / IS metrics are
# created on this device, while the models are moved to `device` from the data cell.
idist.device()

In [None]:
# summary(netG, (latent_dim, 1, 1))

In [None]:
# class Discriminator1x512x512(nn.Module):
#     def __init__(self):
#         super(Discriminator1x512x512, self).__init__()
#         self.model = nn.Sequential(
#             # input is 3 x 512 x 512
#             nn.Conv2d(3, 64, 4, 2, 1, bias=False),
#             nn.LeakyReLU(0.2, inplace=True),
#             # state size. 64 x 256 x 256

#             nn.Conv2d(64, 128, 4, 2, 1, bias=False),
#             nn.BatchNorm2d(128),
#             nn.LeakyReLU(0.2, inplace=True),
#             # state size. 128 x 128 x 128

#             nn.Conv2d(128, 256, 4, 2, 1, bias=False),
#             nn.BatchNorm2d(256),
#             nn.LeakyReLU(0.2, inplace=True),
#             # state size. 256 x 64 x 64

#             nn.Conv2d(256, 512, 4, 2, 1, bias=False),
#             nn.BatchNorm2d(512),
#             nn.LeakyReLU(0.2, inplace=True),
#             # state size. 512 x 32 x 32

#             nn.Conv2d(512, 1024, 4, 2, 1, bias=False),
#             nn.BatchNorm2d(1024),
#             nn.LeakyReLU(0.2, inplace=True),
#             # state size. 1024 x 16 x 16

#             nn.Conv2d(1024, 2048, 4, 2, 1, bias=False),
#             nn.BatchNorm2d(2048),
#             nn.LeakyReLU(0.2, inplace=True),
#             # state size. 2048 x 8 x 8

#             nn.Conv2d(2048, 1, 8, 1, 0, bias=False),
#             nn.Sigmoid()
#         )

#     def forward(self, x):
#         x = self.model(x)
#         return x

In [None]:
class Discriminator1x512x512(nn.Module):
    """DCGAN-style discriminator that scores RGB images as real (1) or fake (0).

    Despite the "512x512" in the class name, the layer sizes are arranged for
    3 x 256 x 256 inputs. Six stride-2 convolutions halve the spatial size
    (256 -> 128 -> 64 -> 32 -> 16 -> 8 -> 4) while widening channels 8 -> 256.
    A final 4x4 unpadded convolution reduces 4 x 4 to 1 x 1, and Sigmoid maps the
    score into [0, 1]. For a 256 x 256 input the output shape is (N, 1, 1, 1).
    """

    def __init__(self):
        super().__init__()
        widths = [8, 16, 32, 64, 128, 256]

        # First stage has no BatchNorm: 3 x 256 x 256 -> 8 x 128 x 128
        layers = [nn.Conv2d(3, widths[0], 4, 2, 1, bias=False), nn.SiLU(inplace=True)]
        # Remaining downsampling stages: conv -> BatchNorm -> SiLU, ending at 256 x 4 x 4
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers.extend([
                nn.Conv2d(c_in, c_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.SiLU(inplace=True),
            ])
        # 256 x 4 x 4 -> 1 x 1 x 1 probability
        layers.extend([nn.Conv2d(widths[-1], 1, 4, 1, 0, bias=False), nn.Sigmoid()])
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return per-image real/fake probabilities for the batch ``x``."""
        return self.model(x)

In [None]:
# Instantiate the discriminator and move it to the training device.
netD = Discriminator1x512x512().to(device)
# summary(netD, (3, 256, 256))

In [None]:
# Binary cross-entropy on probabilities: netD ends in nn.Sigmoid, so its outputs are
# already in [0, 1] and are compared against real_label / fake_label targets.
criterion = nn.BCELoss()

In [None]:
# A fixed batch of 10 latent vectors, reused for every sample snapshot so generated
# images are comparable across iterations / epochs.
fixed_noise = torch.randn(10, latent_dim, 1, 1, device=device)

In [None]:
# Both networks use the same Adam settings (lr = 2e-4, beta1 = 0.5).
adam_settings = {"lr": 0.0002, "betas": (0.5, 0.999)}
optimizerD = optim.Adam(netD.parameters(), **adam_settings)
optimizerG = optim.Adam(netG.parameters(), **adam_settings)

In [None]:
# BCELoss targets: 1 marks "real", 0 marks "fake".
real_label = 1
fake_label = 0


def training_step(engine, data):
    """Run one DCGAN iteration: update the discriminator, then the generator.

    Args:
        engine: the ignite ``Engine`` driving this step (not used here).
        data: one batch from ``train_dataloader``; only ``data[0]`` (the images)
            is used. Presumably (N, 3, H, W) from ImageFolder — confirm if the
            transform changes.

    Returns:
        dict with the scalar losses ``Loss_G`` / ``Loss_D`` and mean discriminator
        outputs: ``D_x`` on the real batch, ``D_G_z1`` on the fakes during the D
        step, and ``D_G_z2`` on the same fakes during the G step (after D's update).
    """
    netG.train()
    netD.train()

    # ---- Discriminator step: minimise BCE(D(x), 1) + BCE(D(G(z)), 0) ----
    netD.zero_grad()

    real = data[0].to(device)
    b_size = real.size(0)
    # One target tensor is reused for all three losses and refilled in place; each
    # fill_ happens only after the loss that used the previous value was backpropagated.
    label = torch.full((b_size,), real_label, dtype=torch.float, device=device)

    # view(-1) flattens D's (N, 1, 1, 1) output (for 256x256 inputs) to (N,) to match label.
    output1 = netD(real).view(-1)
    errD_real = criterion(output1, label)
    errD_real.backward()

    noise = torch.randn(b_size, latent_dim, 1, 1, device=device)
    fake = netG(noise)
    label.fill_(fake_label)
    
    # detach() so this loss does not backpropagate into the generator.
    output2 = netD(fake.detach()).view(-1)
    errD_fake = criterion(output2, label)
    errD_fake.backward()
    errD = errD_real + errD_fake
    optimizerD.step()


    # ---- Generator step: push D(G(z)) towards the "real" label ----
    netG.zero_grad()
    label.fill_(real_label) 
    # Re-scores the same fakes with the just-updated discriminator. Gradients also
    # accumulate in netD's parameters here; netD.zero_grad() above clears them on the
    # next call before they are used.
    output3 = netD(fake).view(-1)
    errG = criterion(output3, label)
    errG.backward()
    optimizerG.step()
    
    return {
        "Loss_G" : errG.item(),
        "Loss_D" : errD.item(),
        "D_x": output1.mean().item(),
        "D_G_z1": output2.mean().item(),
        "D_G_z2": output3.mean().item(),
    }

In [None]:
# Training engine: calls training_step for each batch of the loader passed to run().
# The event handlers attached in later cells (weight init, loss/image logging,
# evaluation) hook into this engine.
trainer = Engine(training_step)

In [None]:
def initialize_fn(m):
    """DCGAN weight initialisation, intended for use with ``nn.Module.apply``.

    - Modules whose class name contains ``"Conv"`` (Conv2d, ConvTranspose2d, ...):
      weights drawn from N(0, 0.02); conv biases are left untouched.
    - Modules whose class name contains ``"BatchNorm"``: weights drawn from
      N(1, 0.02) and biases set to 0.

    Parameters that are ``None`` or missing (e.g. ``BatchNorm2d(affine=False)``) are
    skipped rather than raising ``AttributeError``.

    Args:
        m: the module being visited.
    """
    classname = type(m).__name__
    weight = getattr(m, "weight", None)
    bias = getattr(m, "bias", None)
    # nn.init.* already runs under no_grad, so the parameters can be passed directly
    # (no need for the legacy `.data` access).
    if "Conv" in classname:
        if weight is not None:
            nn.init.normal_(weight, 0.0, 0.02)
    elif "BatchNorm" in classname:
        if weight is not None:
            nn.init.normal_(weight, 1.0, 0.02)
        if bias is not None:
            nn.init.constant_(bias, 0)

In [None]:
@trainer.on(Events.STARTED)
def init_weights():
    """Re-initialise both networks with ``initialize_fn`` when the trainer starts.

    The order (netD, then netG) is kept fixed because both draw from the
    global random number generator.
    """
    for network in (netD, netG):
        network.apply(initialize_fn)

In [None]:
G_losses = []
D_losses = []

@trainer.on(Events.ITERATION_COMPLETED)
def store_losses(engine):
    """Append this iteration's generator and discriminator losses to the history lists."""
    step_output = engine.state.output
    for history, key in ((G_losses, "Loss_G"), (D_losses, "Loss_D")):
        history.append(step_output[key])

In [None]:
img_list = []

@trainer.on(Events.ITERATION_COMPLETED(every=500))
def store_images(engine):
    """Every 500 iterations, append the generator's output for ``fixed_noise`` to ``img_list``.

    The snapshot runs in eval mode so BatchNorm uses its running statistics instead of
    recomputing (and updating) them from the 10 fixed-noise samples. This also matches
    the mode used for the per-epoch sample grids. The previous train/eval mode is
    restored afterwards, even if generation fails.
    """
    was_training = netG.training
    netG.eval()
    try:
        with torch.no_grad():
            fake = netG(fixed_noise).cpu()
    finally:
        netG.train(was_training)
    img_list.append(fake)

In [None]:
from ignite.metrics import FID, InceptionScore

In [None]:
# Frechet Inception Distance between generated and real images, computed on ignite's
# distributed device. evaluation_step returns (fake, real), consumed as (y_pred, y).
fid_metric = FID(device=idist.device())

In [None]:
# Inception Score uses only the generated images: evaluation_step returns (fake, real),
# and output_transform keeps element 0.
is_metric = InceptionScore(device=idist.device(), output_transform=lambda x: x[0])

In [None]:
import PIL.Image as Image


def interpolate(batch, size=None):
    """Convert a batch of images in [-1, 1] into [0, 1] tensors for the FID / IS metrics.

    The pixel values are mapped from [-1, 1] to [0, 1] *before* 8-bit quantisation.
    A ToPILImage-style conversion computes ``x.mul(255).byte()``, and casting negative
    floats to uint8 does not clamp, so dark pixels of [-1, 1] data get corrupted.
    Both the generator (Tanh) and the data pipeline (Normalize with 0.5 / 0.5) produce
    [-1, 1] images in this notebook.

    Args:
        batch: image tensor, assumed (N, C, H, W) with values in [-1, 1]; values
            outside that range are clamped.
        size: optional side length. When given, images are bilinearly resized to
            ``size x size`` (e.g. 299 for Inception) before quantisation. ``None``
            (the default) keeps the input resolution.

    Returns:
        float32 CPU tensor of shape (N, C, H', W') in [0, 1], quantised to the 256
        levels of an 8-bit image. This is the same truncating quantisation as a
        ToPILImage -> ToTensor round trip.
    """
    images = batch.detach().to("cpu", torch.float32)
    images = ((images + 1) / 2).clamp(0, 1)
    if size is not None:
        images = torch.nn.functional.interpolate(
            images, size=(size, size), mode="bilinear", align_corners=False
        ).clamp(0, 1)
    # mul(255) + truncating uint8 cast (as ToPILImage does), then / 255 (as ToTensor does).
    return images.mul(255).to(torch.uint8).to(torch.float32).div(255)


def evaluation_step(engine, batch):
    """Produce a ``(fake, real)`` pair of image batches for the FID / IS metrics.

    One noise vector is drawn per real image, rather than relying on the global
    ``batch_size``. This keeps both halves the same size even if a loader yields
    a short final batch. Both batches go through ``interpolate`` so they are
    post-processed identically.

    Args:
        engine: the evaluator ``Engine`` (not used here).
        batch: one batch from ``test_dataloader``; ``batch[0]`` holds the real images.

    Returns:
        ``(fake, real)``: FID reads this as ``(y_pred, y)``, and IS reads ``fake``
        through its ``output_transform``.
    """
    netG.eval()
    real_images = batch[0]
    n_images = real_images.size(0)
    with torch.no_grad():
        noise = torch.randn(n_images, latent_dim, 1, 1, device=device)
        fake = interpolate(netG(noise))
        real = interpolate(real_images)
    return fake, real

In [None]:
# Evaluation engine: each step pairs freshly generated images with a real batch, and
# FID / IS are accumulated over the run and exposed as state.metrics["fid"] / ["is"].
evaluator = Engine(evaluation_step)
fid_metric.attach(evaluator, "fid")
is_metric.attach(evaluator, "is")

In [None]:
fid_values = []
is_values = []


@trainer.on(Events.EPOCH_COMPLETED)
def log_training_results(engine):
    """After each training epoch, evaluate FID / IS and save a sample grid.

    - Runs ``evaluator`` once over ``test_dataloader``, appends the scores to
      ``fid_values`` / ``is_values``, and prints them against the trainer's
      configured ``max_epochs``.
    - Saves the first 10 generator outputs for ``fixed_noise`` (5 per row) to
      ``img/FLIR_DCGAN/<epoch>_epoch.png``. The directory is created if missing,
      so the first epoch does not crash on a fresh checkout.
    """
    evaluator.run(test_dataloader, max_epochs=1)
    metrics = evaluator.state.metrics
    fid_score = metrics['fid']
    is_score = metrics['is']
    fid_values.append(fid_score)
    is_values.append(is_score)

    epoch_number = engine.state.epoch
    # Use the real epoch budget instead of a hard-coded total, and print 4 decimals
    # (":.4f"; the old ":4f" only set a minimum width).
    print(f"Epoch [{epoch_number}/{engine.state.max_epochs}] Metric Scores")
    print(f"*   FID : {fid_score:.4f}")
    print(f"*    IS : {is_score:.4f}")

    # Generate the sample grid in eval mode; training_step switches netG back to train mode.
    netG.eval()
    with torch.no_grad():
        fake = netG(fixed_noise)

    image_dir = "img/FLIR_DCGAN"
    os.makedirs(image_dir, exist_ok=True)
    image_filename = f"{image_dir}/{epoch_number}_epoch.png"
    save_image(fake[:10], image_filename, nrow=5, normalize=True)


In [None]:
from ignite.metrics import RunningAverage

# Smoothed (running-average) losses, exposed as trainer metrics for the progress bar.
for metric_name in ("Loss_G", "Loss_D"):
    RunningAverage(output_transform=lambda out, key=metric_name: out[key]).attach(trainer, metric_name)

In [None]:
from ignite.contrib.handlers import ProgressBar

# Progress bars for both engines; the trainer's bar also shows the running-average losses.
# NOTE(review): newer ignite releases expose ProgressBar from ignite.handlers and deprecate
# ignite.contrib.handlers — confirm against the installed version.
ProgressBar().attach(trainer, metric_names=['Loss_G','Loss_D'])
ProgressBar().attach(evaluator)

In [None]:
def training(*args, max_epochs=180):
    """Run the GAN trainer over ``train_dataloader``.

    Args:
        *args: accepted and ignored. Presumably this lets the function be used as an
            ``idist.Parallel.run`` target, which passes extra positional arguments.
        max_epochs: number of training epochs (default 180, the previously
            hard-coded value).
    """
    trainer.run(train_dataloader, max_epochs=max_epochs)

In [None]:
# Sanity check: shape of one batch of training images
# (presumably (batch_size, 3, image_size, image_size) given the transform above).
real_batch = next(iter(train_dataloader))
print(real_batch[0].shape)

In [None]:
# with idist.Parallel(backend='nccl') as parallel:
    # parallel.run(training)
training()

# Observation: results seem better when the two losses sum to less than 3?
# Discriminator: use few filters (start at 4 ~ 8).
# Generator: use many filters (start at 2048 ~ 4096); the current generator starts at 1024.

In [None]:
%matplotlib inline 
 
fig, ax = plt.subplots(figsize=(10, 5))
ax.set_title("Generator and Discriminator Loss During Training")
# One curve per network, one point per training iteration.
for history, curve_label in ((G_losses, "G"), (D_losses, "D")):
    ax.plot(history, label=curve_label)
ax.set_xlabel("iterations")
ax.set_ylabel("Loss")
ax.legend()

fig.savefig('DCGAN_training_loss.png')
plt.show()

In [None]:
fig, ax1 = plt.subplots()
ax1.set_title("Evaluation Metric During Training")

# Left axis: Inception Score (one value per epoch).
ax1.set_xlabel('epochs')
ax1.set_ylabel('IS', color='tab:red')
ax1.plot(is_values, color='tab:red')

# Right axis (shared x): FID.
ax2 = ax1.twinx()
ax2.set_ylabel('FID', color='tab:blue')
ax2.plot(fid_values, color='tab:blue')

fig.tight_layout()
fig.savefig('DCGAN_evaluation_metric.png')
plt.show()


In [None]:
%matplotlib inline

# Grab a batch of real images from the dataloader and grid the first 9 of them.
real_batch = next(iter(train_dataloader))
real_grid = vutils.make_grid(real_batch[0][:9], padding=5, normalize=True).cpu()

plt.figure(figsize=(30, 30))

# Left panel: real images.
plt.subplot(1, 2, 1)
plt.axis("off")
plt.title("Real Images")
plt.imshow(np.transpose(real_grid, (1, 2, 0)))

# Right panel: the most recent fixed-noise snapshot collected by store_images
# (taken every 500 iterations, so not necessarily at the end of the last epoch).
fake_grid = vutils.make_grid(img_list[-1], padding=2, normalize=True).cpu()
plt.subplot(1, 2, 2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(fake_grid, (1, 2, 0)))

plt.savefig('DCGAN_real_fake_images_comparison.png')
plt.show()