In [None]:
import torch
import torchvision
import ignite

# Report the installed version of each core dependency, one "name: version" per line.
print("\n".join(f"{pkg.__name__}: {pkg.__version__}" for pkg in (torch, torchvision, ignite)))

In [None]:
import os
import logging
import matplotlib.pyplot as plt

import cv2
import numpy as np

from torchsummary import summary

import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.utils import save_image
import torchvision.datasets as dset
import torchvision.utils as vutils
from torch.autograd import Variable

from ignite.engine import Engine, Events
import ignite.distributed as idist

def weights_init_normal(m):
    """DCGAN-style in-place initialiser, meant to be passed to ``nn.Module.apply``.

    Dispatch is by class name:

    * names containing ``"Conv"``      -> weight ~ N(0.0, 0.02)
    * names containing ``"BatchNorm"`` -> weight ~ N(1.0, 0.02), bias = 0

    Modules matching a name but lacking the parameter (for example a container
    class called ``ConvBlock``, or a norm layer built with ``affine=False``) are
    skipped instead of raising ``AttributeError``.

    Note: ``InstanceNorm*`` layers match neither pattern, so they keep
    PyTorch's default initialisation.

    Args:
        m: the module currently visited by ``apply``.
    """
    classname = m.__class__.__name__
    # getattr with a default covers both "attribute missing" and "parameter is None".
    weight = getattr(m, "weight", None)
    bias = getattr(m, "bias", None)

    if "Conv" in classname:
        if weight is not None:
            torch.nn.init.normal_(weight.data, 0.0, 0.02)
    elif "BatchNorm" in classname:
        if weight is not None:
            torch.nn.init.normal_(weight.data, 1.0, 0.02)
        if bias is not None:
            torch.nn.init.constant_(bias.data, 0.0)

In [None]:
# ignite.utils.manual_seed(999)
# ignite.utils.setup_logger(name="ignite.distributed.auto.auto_dataloader", level=logging.WARNING)
# ignite.utils.setup_logger(name="ignite.distributed.launcher.Parallel", level=logging.WARNING)

In [None]:
class Option:
    """Hyper-parameters and run settings shared by the cells of this notebook."""

    # --- optimisation ---
    n_epochs = 200          # number of epochs to train for
    batch_size = 9          # size of each mini-batch
    lr = 0.0002             # Adam optimizer learning rate
    b1 = 0.5                # Adam: decay of the first-order momentum of the gradient
    b2 = 0.999              # Adam: decay of the second-order momentum of the gradient

    # --- data ---
    n_cpu = 16              # number of CPU threads used during batch generation
    img_size = 512          # size of each image dimension
    channels = 1            # number of image channels

    # --- model / sampling ---
    latent_dim = 150        # dimensionality of the latent space
    sample_interval = 500   # interval between image samples


opt = Option()

In [None]:
# Pre-processing pipeline: resize and centre-crop to a square of opt.img_size,
# convert to a single grayscale channel, then normalise with mean 0.5 / std 0.5.
data_transform = transforms.Compose(
    [
        transforms.Resize(opt.img_size),
        transforms.CenterCrop(opt.img_size),
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
        # One-element tuples: `(0.5)` is just the float 0.5, not a per-channel sequence.
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

# GPU index 3 is the preferred device. On a CUDA host with fewer than four
# GPUs that ordinal is invalid, so fall back to the first GPU; without CUDA
# run on the CPU.
if not torch.cuda.is_available():
    device = torch.device("cpu")
elif torch.cuda.device_count() > 3:
    device = torch.device("cuda:3")
else:
    device = torch.device("cuda:0")

train_dataset = dset.ImageFolder(root="../datasets/HighResolution/FLIR", transform=data_transform)
# test_dataset = torch.utils.data.Subset(train_dataset, torch.arange(3000))

In [None]:
# Shuffled training loader; drop_last discards a trailing batch that is
# smaller than opt.batch_size.
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=opt.batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=opt.n_cpu,
)

# test_dataloader = DataLoader(
#     test_dataset, 
#     batch_size=batch_size, 
#     num_workers=8, 
#     shuffle=False, 
#     drop_last=True,
# )

In [None]:
# Fetch a single batch and tile its first four images into one preview grid.
real_batch = next(iter(train_dataloader))
preview_grid = vutils.make_grid(real_batch[0][:4], padding=2, normalize=True).cpu()

plt.figure(figsize=(20, 20))
plt.title("Training Images")
plt.axis("off")
# make_grid returns (C, H, W); imshow expects (H, W, C).
plt.imshow(preview_grid.permute(1, 2, 0).numpy())
plt.show()

# Shape of one sample of the batch.
print(real_batch[0][0].shape)

In [None]:
class Generator(nn.Module):
    """Maps a latent vector to an image through eight 2x upsampling stages.

    A linear layer projects the latent vector to a 4096-channel seed feature
    map of side ``init_size``. Each stage then doubles the spatial size and
    halves the channel count (4096 -> 16). A final convolution followed by
    ``Tanh`` produces ``opt.channels`` output channels.
    """

    def __init__(self):
        super().__init__()

        # Side of the seed feature map; eight 2x upsamplings (x256) bring it
        # back to opt.img_size (2 -> 512 for the configured size).
        self.init_size = opt.img_size // 256
        self.l1 = nn.Sequential(
            nn.Linear(in_features=opt.latent_dim, out_features=4096 * self.init_size ** 2),
        )

        # Channel width before/after every upsampling stage.
        widths = (4096, 2048, 1024, 512, 256, 128, 64, 32, 16)

        layers = []
        for in_ch, out_ch in zip(widths[:-1], widths[1:]):
            layers.extend(
                [
                    nn.Upsample(scale_factor=2),
                    nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1),
                    nn.InstanceNorm2d(out_ch, affine=True),
                    nn.LeakyReLU(0.2, inplace=True),
                ]
            )

        # Output head: project to the image channels and squash with Tanh.
        layers.append(nn.Conv2d(widths[-1], opt.channels, 3, stride=1, padding=1))
        layers.append(nn.Tanh())

        self.conv_blocks = nn.Sequential(*layers)

    def forward(self, z):
        """Generate images from the latent batch ``z``."""
        seed = self.l1(z)
        # Reshape the flat projection into (batch, 4096, init_size, init_size).
        seed = seed.view(seed.shape[0], 4096, self.init_size, self.init_size)
        return self.conv_blocks(seed)

In [None]:
# Stand-alone generator instance placed on the selected device.
netG = Generator().to(device)

In [None]:
# Show the configured latent-space dimensionality (the generator's input size).
print(opt.latent_dim)

In [None]:

# summary(netG, input_size=(1, 1, opt.latent_dim), device=device.type)

In [None]:
class Discriminator(nn.Module):
    """Convolutional critic that scores an image with a single unbounded value.

    Seven stride-2 blocks shrink the input by a factor of 2**7 while widening
    the channels from ``opt.channels`` to 512. The flattened result is mapped
    to one output by a linear layer. No sigmoid is applied.
    """

    def __init__(self):
        super(Discriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, bn=True):
            """Return the layers of one stride-2 downsampling block.

            Args:
                in_filters: input channel count.
                out_filters: output channel count.
                bn: when true, append an instance-normalisation layer.
            """
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]
            if bn:
                # `affine` must be passed by keyword: the second positional
                # parameter of InstanceNorm2d is `eps`, so a positional `True`
                # set eps=1.0 and left the layer without learnable scale/shift.
                block.append(nn.InstanceNorm2d(out_filters, affine=True))
            return block

        self.model = nn.Sequential(
            *discriminator_block(opt.channels, 8, bn=False), # 256
            *discriminator_block(8, 16),  # 128
            *discriminator_block(16, 32),  # 64
            *discriminator_block(32, 64), # 32
            *discriminator_block(64, 128),# 16
            *discriminator_block(128, 256), # 8
            *discriminator_block(256, 512), # 4
        )
        # The height and width of downsampled image
        ds_size = opt.img_size // 2 ** 7 # 4
        self.adv_layer = nn.Linear(512 * ds_size ** 2, 1)

    def forward(self, img):
        """Return one validity score per image in the batch, shape (batch, 1)."""
        out = self.model(img)
        # Flatten all feature maps into one vector per sample.
        out = out.view(out.shape[0], -1)
        validity = self.adv_layer(out)

        return validity

In [None]:
# Stand-alone discriminator instance placed on the selected device.
netD = Discriminator().to(device)
# summary(netD, (1, 512, 512))

In [None]:
# Mean-squared-error adversarial objective.
adversarial_loss = torch.nn.MSELoss()

# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()

# A torch.device object is always truthy, so the former `if device:` guard
# never skipped anything; move everything to the selected device directly.
generator.to(device=device)
discriminator.to(device=device)
adversarial_loss.to(device=device)

# Re-initialise the weights with the DCGAN-style normal initialiser.
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)

# One Adam optimizer per network, sharing the configured hyper-parameters.
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

# Tensor constructor matching the device. The choice must look at the device
# *type*: testing the device object itself always selected the CUDA type,
# which cannot be instantiated on a CPU-only host.
Tensor = torch.cuda.FloatTensor if device.type == "cuda" else torch.FloatTensor

In [None]:
# ----------
#  Training
# ----------

for epoch in range(opt.n_epochs):
    for i, (imgs, _) in enumerate(train_dataloader):
        # Configure input: real images on the training device as float32.
        imgs = imgs.to(device)
        real_imgs = imgs.to(dtype=torch.float32)
        n_real = real_imgs.shape[0]

        # Adversarial ground truths, created directly on the target device.
        # (The deprecated Variable wrapper and the legacy Tensor(...) constructor
        # are avoided: the latter allocates on the default device first and
        # cannot be used at all when it resolves to a CUDA type on a CPU host.)
        valid = torch.ones(n_real, 1, device=device)
        fake = torch.zeros(n_real, 1, device=device)

        # -----------------
        #  Train Generator
        # -----------------

        optimizer_G.zero_grad()

        # Sample standard-normal noise as generator input
        z = torch.randn(n_real, opt.latent_dim, device=device)

        # Generate a batch of images
        gen_imgs = generator(z)

        # Loss measures generator's ability to fool the discriminator
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)

        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Measure discriminator's ability to classify real from generated samples.
        # gen_imgs is detached so this step does not back-propagate into the generator.
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

    # Once per epoch: report the losses of the last processed batch.
    print(
        "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
        % (epoch, opt.n_epochs, i, len(train_dataloader), d_loss.item(), g_loss.item())
    )
    # save_image(gen_imgs.data, "img/FLIR_LSGAN_MF/%d.png" % (epoch+1), nrow=5, normalize=True)
    
    # plt.figure(figsize = (10,10))
    # img1 = cv2.imread("img/FLIR_LSGAN_MF/%d.png" % (epoch+1))
    # plt.imshow(img1, interpolation='nearest')
    # plt.axis('off')
    # plt.show()
