In [None]:
import torch
import torchvision
import ignite

# Report each core dependency as "name: version", one per line.
for _module in (torch, torchvision, ignite):
    print(f"{_module.__name__}: {_module.__version__}")

In [None]:
import os
import logging
import matplotlib.pyplot as plt

import cv2
import numpy as np

from torchsummary import summary

import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.utils import save_image
import torchvision.datasets as dset
import torchvision.utils as vutils

from ignite.engine import Engine, Events
import ignite.distributed as idist

In [None]:
# Fix the random seed so runs are repeatable.
ignite.utils.manual_seed(999)

# Raise these two ignite.distributed loggers to WARNING level.
for _logger_name in (
    "ignite.distributed.auto.auto_dataloader",
    "ignite.distributed.launcher.Parallel",
):
    ignite.utils.setup_logger(name=_logger_name, level=logging.WARNING)

In [None]:
# Side length (in pixels) that every image is resized and center-cropped to.
image_size = 256

# Preprocessing: resize -> center crop -> single-channel grayscale -> tensor ->
# normalize with mean 0.5 / std 0.5.
data_transform = transforms.Compose(
    [
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
        # One-element tuples: `(0.5)` is just the float 0.5, not a per-channel sequence.
        transforms.Normalize((0.5,), (0.5,)),
    ]
)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
train_dataset = dset.ImageFolder(root="datasets/HighResolution/FLIR", transform=data_transform)

# The evaluation set is the first (up to) 3000 samples of the training set, not
# a held-out split. Clamp to the dataset length so a smaller folder does not
# raise IndexError when the subset is iterated.
eval_subset_size = min(3000, len(train_dataset))
test_dataset = torch.utils.data.Subset(train_dataset, range(eval_subset_size))

In [None]:
batch_size = 9

# Options common to both loaders; only shuffling differs between train and test.
_shared_loader_options = dict(batch_size=batch_size, num_workers=8, drop_last=True)

train_dataloader = DataLoader(train_dataset, shuffle=True, **_shared_loader_options)

test_dataloader = DataLoader(test_dataset, shuffle=False, **_shared_loader_options)

In [None]:
# Pull a single batch and show its first four images as a grid.
real_batch = next(iter(train_dataloader))
preview_images = real_batch[0]

preview_grid = vutils.make_grid(preview_images[:4], padding=2, normalize=True).cpu()

plt.figure(figsize=(20,20))
plt.axis("off")
plt.title("Training Images")
# Reorder axes (1, 2, 0) so the leading axis of the grid becomes the last one for imshow.
plt.imshow(np.transpose(preview_grid, (1,2,0)))
plt.show()
print(preview_images[0].shape)

In [None]:
# Channel count of the generator's input noise; noise tensors in this notebook
# are sampled with shape (batch, latent_dim, 1, 1).
latent_dim = 256

In [None]:
class depthwise_conv(nn.Module):
    """3x3 grouped convolution with one group per input channel.

    Args:
        nin: number of input channels (also the number of groups).
        kernels_per_layer: output channels produced per input channel; the
            layer outputs ``nin * kernels_per_layer`` channels.

    With ``kernel_size=3`` and ``padding=1`` the spatial size is preserved.
    """

    def __init__(self, nin, kernels_per_layer):
        super().__init__()
        self.depthwise = nn.Conv2d(
            nin,
            nin * kernels_per_layer,
            kernel_size=3,
            padding=1,
            groups=nin,
        )

    def forward(self, x):
        """Apply the grouped convolution to ``x`` and return the result."""
        return self.depthwise(x)

In [None]:
class pointwise_conv(nn.Module):
    """1x1 convolution that maps ``nin`` channels to ``nout`` channels.

    Args:
        nin: number of input channels.
        nout: number of output channels.
    """

    def __init__(self, nin, nout):
        super().__init__()
        self.pointwise = nn.Conv2d(nin, nout, kernel_size=1)

    def forward(self, x):
        """Apply the 1x1 convolution to ``x`` and return the result."""
        return self.pointwise(x)

In [None]:
class Generator(nn.Module):
    """Transposed-convolution generator mapping latent noise to a 1-channel image.

    Each ConvTranspose2d below uses kernel_size=4, stride=2, padding=1, which
    doubles the spatial size; the depthwise_conv layers (3x3, padding 1) keep it
    unchanged. The "state size" comments give channels x height x width after
    each stage for an input of shape (latent_dim x 1 x 1).

    Args:
        latent_dim: number of channels of the input noise tensor.
    """

    def __init__(self, latent_dim):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, 2048, kernel_size=4, stride=2, padding=1, bias=False),
            # depthwise_conv(1024, 1),
            nn.InstanceNorm2d(2048),
            nn.SELU(True),
            # state size. 2048 x 2 x 2

            depthwise_conv(2048, 1),
            nn.ConvTranspose2d(2048, 1024, kernel_size=4, stride=2, padding=1, bias=False),
            nn.InstanceNorm2d(1024),
            nn.SELU(True),
            # state size. 1024 x 4 x 4

            depthwise_conv(1024, 1),
            nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1, bias=False),
            nn.InstanceNorm2d(512),
            nn.SELU(True),
            # state size. 512 x 8 x 8

            depthwise_conv(512, 1),
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=False),
            nn.InstanceNorm2d(256),
            nn.SELU(True),
            # state size. 256 x 16 x 16

            # From this stage on, an extra depthwise_conv follows each upsampling step.
            depthwise_conv(256, 1),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False),
            depthwise_conv(128, 1),
            nn.InstanceNorm2d(128),
            nn.SELU(True),
            # state size. 128 x 32 x 32

            # Channel count grows again from here (128 -> 256 -> 512) before the final 1-channel output.
            depthwise_conv(128, 1),
            nn.ConvTranspose2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False),
            depthwise_conv(256, 1),
            nn.InstanceNorm2d(256),
            nn.SELU(True),
            # state size. 256 x 64 x 64

            depthwise_conv(256, 1),
            nn.ConvTranspose2d(256, 512, kernel_size=4, stride=2, padding=1, bias=False),
            depthwise_conv(512, 1),
            nn.InstanceNorm2d(512),
            nn.SELU(True),
            # state size. 512 x 128 x 128

            depthwise_conv(512, 1),
            nn.ConvTranspose2d(512, 1, kernel_size=4, stride=2, padding=1, bias=False),
            # NOTE(review): SELU is applied immediately before Tanh here, so the
            # Tanh input is bounded below by SELU's negative saturation value.
            # Confirm this is intentional.
            nn.SELU(True),
            nn.Tanh(),
            # state size. 1 x 256 x 256
        )

    def forward(self, x):
        """Run ``x`` through the sequential stack and return the generated tensor."""
        x = self.model(x)
        return x

In [None]:
# Build the generator and place it on the selected device in one step.
netG = Generator(latent_dim).to(device)

In [None]:
# torch.cuda.set_device(0)  # select GPU 0
# idist.device()

In [None]:
# summary(netG, (latent_dim, 1, 1))

In [None]:
class Discriminator(nn.Module):
    """Strided-convolution discriminator producing one sigmoid score per image.

    Six downsampling stages (Conv2d with kernel 3, stride 2, padding 1, then a
    depthwise_conv and SiLU) halve the spatial size each time; for a
    1 x 256 x 256 input that is 256 -> 128 -> 64 -> 32 -> 16 -> 8 -> 4.
    Every stage except the first also applies InstanceNorm2d. The features
    are then globally average-pooled, flattened, and mapped to a single
    value through a Linear layer and a Sigmoid.
    """

    def __init__(self):
        super().__init__()

        # Channel count entering the network followed by the width after each stage.
        stage_channels = (1, 8, 16, 32, 64, 128, 256)

        stack = []
        for stage_index, (c_in, c_out) in enumerate(zip(stage_channels, stage_channels[1:])):
            stack.append(nn.Conv2d(c_in, c_out, 3, 2, 1, bias=False))
            stack.append(depthwise_conv(c_out, 1))  # Depthwise Convolution
            if stage_index > 0:
                # The first stage has no normalization layer.
                stack.append(nn.InstanceNorm2d(c_out))
            stack.append(nn.SiLU(inplace=True))

        # Global average pooling collapses the spatial dimensions before flattening.
        stack.append(nn.AdaptiveAvgPool2d((1,1)))
        stack.append(nn.Flatten())
        # Final Linear layer reduces the 256 pooled features to 1 output.
        stack.append(nn.Linear(256, 1))
        stack.append(nn.Sigmoid())

        self.model = nn.Sequential(*stack)

    def forward(self, x):
        """Return the discriminator score for each image in ``x``."""
        return self.model(x)

In [None]:
# Build the discriminator and place it on the selected device in one step.
netD = Discriminator().to(device)
# summary(netD, (1, 256, 256))

In [None]:
# Loss terms combined inside training_step.
# NOTE: despite the names, criterionL1 is mean-squared error and criterionL2 is
# binary cross-entropy (neither is an L1 loss).
criterionL1 = nn.MSELoss()
criterionL2 = nn.BCELoss()

In [None]:
# A fixed set of 10 latent vectors, reused whenever sample images are generated
# so that snapshots taken at different points in training are comparable.
fixed_noise = torch.randn(10, latent_dim, 1, 1, device=device)

In [None]:
# Both networks are optimized with Adam using the same hyperparameters.
_adam_settings = {"lr": 0.0002, "betas": (0.5, 0.999)}

optimizerD = optim.Adam(netD.parameters(), **_adam_settings)
optimizerG = optim.Adam(netG.parameters(), **_adam_settings)

In [None]:
real_label = 1
fake_label = 0


def _weighted_gan_loss(prediction, target):
    """Blend the two criteria for one prediction/target pair.

    Returns a tuple ``(combined, bce, mse)`` where ``bce`` comes from
    criterionL2, ``mse`` from criterionL1, and
    ``combined = 0.4 * bce + 0.6 * mse``.
    """
    bce = criterionL2(prediction, target)
    mse = criterionL1(prediction, target)
    return 0.4 * bce + 0.6 * mse, bce, mse


def training_step(engine, data):
    """Run one adversarial update: a discriminator step, then a generator step.

    Args:
        engine: the calling ignite Engine (not used inside the function).
        data: a batch from the dataloader; ``data[0]`` holds the images.

    Returns:
        dict of Python floats with the combined and per-criterion losses of
        both networks, plus the mean discriminator output on real images
        (``D_x``) and on generated images before (``D_G_z1``) and after
        (``D_G_z2``) the discriminator update.
    """
    netG.train()
    netD.train()

    netD.zero_grad()

    real_images = data[0].to(device)
    n_samples = real_images.size(0)
    target_real = torch.full((n_samples,), real_label, dtype=torch.float, device=device)
    target_fake = torch.full((n_samples,), fake_label, dtype=torch.float, device=device)

    # Discriminator loss on real images
    d_out_real = netD(real_images).view(-1)
    errD_real, errD_real_BCE, errD_real_MSE = _weighted_gan_loss(d_out_real, target_real)
    errD_real.backward()

    # Discriminator loss on generated images; detach() keeps this backward
    # pass from reaching the generator.
    noise = torch.randn(n_samples, latent_dim, 1, 1, device=device)
    generated = netG(noise)
    d_out_fake = netD(generated.detach()).view(-1)
    errD_fake, errD_fake_BCE, errD_fake_MSE = _weighted_gan_loss(d_out_fake, target_fake)
    errD_fake.backward()

    errD = errD_real + errD_fake
    optimizerD.step()

    # Generator loss: score the same generated images with the updated
    # discriminator, using the "real" targets.
    netG.zero_grad()
    d_out_for_g = netD(generated).view(-1)
    errG, errG_BCE, errG_MSE = _weighted_gan_loss(d_out_for_g, target_real)
    errG.backward()
    optimizerG.step()

    return {
        "Total_Loss_D": errD.item(),
        "Loss_D_real_BCE": errD_real_BCE.item(),
        "Loss_D_real_MSE": errD_real_MSE.item(),
        "Loss_D_fake_BCE": errD_fake_BCE.item(),
        "Loss_D_fake_MSE": errD_fake_MSE.item(),
        "Total_Loss_G": errG.item(),
        "Loss_G_BCE": errG_BCE.item(),
        "Loss_G_MSE": errG_MSE.item(),
        "D_x": d_out_real.mean().item(),
        "D_G_z1": d_out_fake.mean().item(),
        "D_G_z2": d_out_for_g.mean().item(),
    }


In [None]:
# Engine that calls training_step on every batch it is run over.
trainer = Engine(training_step)

In [None]:
def initialize_fn(m):
    """Re-initialize a module in place, chosen by its class name.

    * Class name contains "Conv": weights drawn from N(0.0, 0.02).
    * Class name contains "BatchNorm": weights drawn from N(1.0, 0.02) and
      bias set to 0.
    * Anything else is left untouched.

    The match is case-sensitive and "Conv" is checked first.
    """
    layer_type = m.__class__.__name__
    if "Conv" in layer_type:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        return
    if "BatchNorm" in layer_type:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

In [None]:
@trainer.on(Events.STARTED)
def init_weights():
    """Apply initialize_fn to every submodule of both networks when a run starts."""
    for network in (netD, netG):
        network.apply(initialize_fn)

In [None]:
G_losses = []
D_losses = []


@trainer.on(Events.ITERATION_COMPLETED)
def store_losses(engine):
    """Append this iteration's total generator/discriminator loss to the histories."""
    step_output = engine.state.output
    for history, key in ((G_losses, "Total_Loss_G"), (D_losses, "Total_Loss_D")):
        history.append(step_output[key])

In [None]:
img_list = []


@trainer.on(Events.ITERATION_COMPLETED(every=500))
def store_images(engine):
    """Every 500 iterations, keep a CPU copy of the images generated from fixed_noise."""
    with torch.no_grad():
        snapshot = netG(fixed_noise)
    img_list.append(snapshot.cpu())

In [None]:
# from ignite.metrics import FID, InceptionScore

In [None]:
# fid_metric = FID(device=idist.device())

In [None]:
# is_metric = InceptionScore(device=idist.device(), output_transform=lambda x: x[0])

In [None]:
import PIL.Image as Image


def interpolate(batch):
    """Round-trip every image in ``batch`` through a PIL image and back.

    Each element is converted with ToPILImage and then ToTensor, and the
    results are stacked into a single tensor. No resizing is performed.

    NOTE(review): the images passed in elsewhere in this notebook appear to be
    normalized to roughly [-1, 1]; confirm that ToPILImage handles that value
    range as intended.
    """
    to_pil = transforms.ToPILImage()
    to_tensor = transforms.ToTensor()
    return torch.stack([to_tensor(to_pil(image)) for image in batch])


def evaluation_step(engine, batch):
    """Produce a (fake, real) image pair for one evaluation batch.

    Args:
        engine: the calling ignite Engine (not used inside the function).
        batch: a batch from the dataloader; ``batch[0]`` holds the real images.

    Returns:
        tuple ``(fake, real)`` of tensors, both passed through ``interpolate``.
    """
    netG.eval()
    real_images = batch[0]
    with torch.no_grad():
        # Generate exactly as many samples as there are real images, instead
        # of relying on the global batch_size; the two only coincide when the
        # loader drops partial batches and uses that same batch size.
        noise = torch.randn(real_images.size(0), latent_dim, 1, 1, device=device)
        fake_batch = netG(noise)
        fake = interpolate(fake_batch)
        real = interpolate(real_images)
    return fake, real

In [None]:
# Engine that calls evaluation_step on every batch; the FID / Inception Score
# metrics that would consume its output are currently disabled.
evaluator = Engine(evaluation_step)
# fid_metric.attach(evaluator, "fid")
# is_metric.attach(evaluator, "is")

In [None]:
fid_values = []
is_values = []


@trainer.on(Events.EPOCH_COMPLETED)
def log_training_results(engine):
    """After each epoch, run the evaluator and save a grid of generated samples.

    The grid is produced from ``fixed_noise`` and written to
    ``img/FLIR_DCGAN_BASE_BCE_MSE_I/<epoch>_epoch.png``.
    """
    evaluator.run(test_dataloader,max_epochs=1)
    # Metric collection is disabled; kept for reference.
    # metrics = evaluator.state.metrics
    # fid_score = metrics['fid']
    # is_score = metrics['is']
    # fid_values.append(fid_score)
    # is_values.append(is_score)
    # print(f"Epoch [{engine.state.epoch}/100] Metric Scores")
    # print(f"*   FID : {fid_score:4f}")
    # print(f"*    IS : {is_score:4f}")

    with torch.no_grad():
        fake = netG(fixed_noise)

    epoch_number = engine.state.epoch
    image_filename = f"img/FLIR_DCGAN_BASE_BCE_MSE_I/{epoch_number}_epoch.png"
    # Create the output directory on demand; otherwise a missing folder makes
    # save_image fail only after the first full training epoch has finished.
    os.makedirs(os.path.dirname(image_filename), exist_ok=True)
    save_image(fake.data[:10], image_filename, nrow=5, normalize=True)

    # img = cv2.imread(image_filename)
    # plt.figure(figsize=(30, 30))
    # plt.imshow(img, interpolation='nearest')
    # plt.axis('off')
    # plt.show()


In [None]:
from ignite.metrics import RunningAverage

# Running averages of the total generator / discriminator loss, exposed on the
# trainer under the metric names 'Loss_G' and 'Loss_D'.
generator_loss_average = RunningAverage(output_transform=lambda x: x["Total_Loss_G"])
discriminator_loss_average = RunningAverage(output_transform=lambda x: x["Total_Loss_D"])

generator_loss_average.attach(trainer, 'Loss_G')
discriminator_loss_average.attach(trainer, 'Loss_D')

In [None]:
from ignite.contrib.handlers import ProgressBar

# One progress bar per engine; the training bar also shows the two running-average losses.
training_progress = ProgressBar()
training_progress.attach(trainer, metric_names=['Loss_G','Loss_D'])

evaluation_progress = ProgressBar()
evaluation_progress.attach(evaluator)

In [None]:
def training(*args, max_epochs=200):
    """Run the trainer over the training dataloader.

    Args:
        *args: ignored; accepted so the function can also be invoked by a
            launcher that passes positional arguments.
        max_epochs: number of epochs to train for (keyword-only, default 200).
    """
    trainer.run(train_dataloader, max_epochs=max_epochs)

In [None]:
# Sanity check before training: fetch one batch and print the shape of its image tensor.
real_batch = next(iter(train_dataloader))
print(real_batch[0].shape)

In [None]:
# with idist.Parallel(backend='nccl') as parallel:
    # parallel.run(training)
training()

# Observation: results seem better when the sum of the two losses stays below 3?
# Discriminator: use few filters (starting at 4 ~ 8)
# Generator: use many filters (starting at 2048 ~ 4096)

In [None]:
%matplotlib inline 

# Per-iteration generator and discriminator losses on a single set of axes.
loss_fig, loss_ax = plt.subplots(figsize=(10,5))
loss_ax.set_title("Generator and Discriminator Loss During Training")
loss_ax.plot(G_losses, label="G")
loss_ax.plot(D_losses, label="D")
loss_ax.set_xlabel("iterations")
loss_ax.set_ylabel("Loss")
loss_ax.legend()

loss_fig.savefig('DCGAN_BCE_MSE_I_loss.png')
plt.show()

In [None]:
# fig, ax1 = plt.subplots()

# plt.title("Evaluation Metric During Training")

# color = 'tab:red'
# ax1.set_xlabel('epochs')
# ax1.set_ylabel('IS', color=color)
# ax1.plot(is_values, color=color)

# ax2 = ax1.twinx()

# color = 'tab:blue'
# ax2.set_ylabel('FID', color=color)
# ax2.plot(fid_values, color=color)

# fig.tight_layout()

# fig.savefig('DCGAN_evaluation_metric.png')
# plt.show()


In [None]:
%matplotlib inline

# Grab a batch of real images from the dataloader
real_batch = next(iter(train_dataloader))

# Pick the fake images to display: the most recent snapshot stored during
# training (snapshots are taken every 500 iterations). If training ran for
# fewer than 500 iterations no snapshot exists, so generate one now instead of
# failing with an IndexError on an empty list.
if img_list:
    latest_fakes = img_list[-1]
else:
    with torch.no_grad():
        latest_fakes = netG(fixed_noise).cpu()

# Plot the real images
plt.figure(figsize=(30,30))
plt.subplot(1,2,1)
plt.axis("off")
plt.title("Real Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0][:9], padding=5, normalize=True).cpu(),(1,2,0)))

# Plot the fake images
plt.subplot(1,2,2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(vutils.make_grid(latest_fakes, padding=2, normalize=True).cpu(),(1,2,0)))


plt.savefig('DCGAN_BCE_MSE_I_images_comparison.png')
plt.show()