In [1]:
import torch
import torchvision
import ignite

print(*map(lambda m: ": ".join((m.__name__, m.__version__)), (torch, torchvision, ignite)), sep="\n")

torch: 2.2.0+cu118
torchvision: 0.17.0+cu118
ignite: 0.4.13


In [2]:
import os
import logging
import matplotlib.pyplot as plt

import cv2
import numpy as np

from torchsummary import summary

import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.utils import save_image
import torchvision.datasets as dset
import torchvision.utils as vutils
from torch.autograd import Variable
import torch.autograd as autograd
from ignite.engine import Engine, Events
import ignite.distributed as idist


def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)

In [3]:
ignite.utils.manual_seed(999)
ignite.utils.setup_logger(name="ignite.distributed.auto.auto_dataloader", level=logging.WARNING)
ignite.utils.setup_logger(name="ignite.distributed.launcher.Parallel", level=logging.WARNING)



In [4]:
class Option():
    n_epochs = 180   # 훈련할 에포크 수
    batch_size = 9  # 배치의 크기
    lr = 0.0002      # Adam 옵티마이저의 학습률
    b1 = 0.5         # Adam 옵티마이저의 그래디언트의 일차 모멘텀 감쇠
    b2 = 0.999       # Adam 옵티마이저의 그래디언트의 이차 모멘텀 감쇠
    n_cpu = 16        # 배치 생성 중에 사용할 CPU 스레드 수
    latent_dim = 200 # 잠재 공간의 차원
    img_size = 256    # 각 이미지 차원의 크기
    channels = 3     # 이미지 채널 수
    sample_interval = 500  # 이미지 샘플링 간격

opt = Option()
img_shape = (opt.channels, opt.img_size, opt.img_size)

In [5]:
data_transform = transforms.Compose(
    [
        transforms.Resize(opt.img_size),
        transforms.CenterCrop(opt.img_size),
        # transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
        transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5)),
    ]
)
device = torch.device("cuda:4" if torch.cuda.is_available() else "cpu")
train_dataset = dset.ImageFolder(root="datasets/HighResolution/FLIR", transform=data_transform)
test_dataset = torch.utils.data.Subset(train_dataset, torch.arange(3000))

In [6]:
train_dataloader = DataLoader(
    train_dataset, 
    batch_size=opt.batch_size, 
    num_workers=8, 
    shuffle=True, 
    drop_last=True,
)

test_dataloader = DataLoader(
    test_dataset, 
    batch_size=opt.batch_size, 
    num_workers=8, 
    shuffle=False, 
    drop_last=True,
)

In [7]:
class Generator(nn.Module):
    def __init__(self, latent_dim):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.ConvTranspose2d(in_channels=latent_dim, out_channels=1024, 
                               kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(1024),
            nn.SELU(True),
            # state size. 1024 x 4 x 4

            nn.ConvTranspose2d(1024, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.SELU(True),
            # state size. 512 x 8 x 8

            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.SELU(True),
            # state size. 256 x 16 x 16

            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.SELU(True),
            # state size. 128 x 32 x 32

            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.SELU(True),
            # state size. 64 x 64 x 64            

            nn.ConvTranspose2d(64, 32, 4, 2, 1, bias=False),
            nn.BatchNorm2d(32),
            nn.SELU(True),
            # state size. 1 x 128 x 128

            nn.ConvTranspose2d(32, 3, 4, 2, 1, bias=False),
            nn.Tanh(),
            # state size. 1 x 256 x 256
        )

    def forward(self, x):
        x = self.model(x)
        return x

In [8]:
generator = Generator(opt.latent_dim)
generator = generator.to(device)

In [9]:
# summary(generator, (opt.latent_dim, 1, 1))

In [10]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        
        # 인코더 부분
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, 4, stride=2, padding=1),  # 256x256 -> 128x128
            nn.SELU(True),
            nn.Conv2d(32, 64, 4, stride=2, padding=1),  # 128x128 -> 64x64
            nn.SELU(True),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),  # 64x64 -> 32x32
            nn.SELU(True),
            nn.Conv2d(128, 256, 4, stride=2, padding=1),  # 32x32 -> 16x16
            nn.SELU(True),
            nn.Conv2d(256, 512, 4, stride=2, padding=1),  # 16x16 -> 8x8
            nn.SELU(True),
        )
        
        # 임베딩 레이어
        self.embedding = nn.Linear(512 * 8 * 8, 2048)
        
        self.decoder_input_layer = nn.Linear(2048, 512 * 8 * 8)
        # 디코더 부분
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1),  # 8x8 -> 16x16
            nn.SELU(True),
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),  # 16x16 -> 32x32
            nn.SELU(True),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),  # 32x32 -> 64x64
            nn.SELU(True),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),  # 64x64 -> 128x128
            nn.SELU(True),
            nn.ConvTranspose2d(32, 1, 4, stride=2, padding=1),  # 128x128 -> 256x256
            nn.SELU(True),
        )

    def forward(self, img):
        encoded = self.encoder(img)
        encoded_flat = encoded.view(encoded.size(0), -1)
        embedding = self.embedding(encoded_flat)
        decoder_input = self.decoder_input_layer(embedding)
        decoder_input = decoder_input.view(encoded.size(0), 512, 8, 8)
        decoded = self.decoder(decoder_input)
        return decoded, embedding


In [11]:
discriminator = Discriminator()
discriminator = discriminator.to(device)
# summary(discriminator, (1, 256, 256))

In [12]:
pixelwise_loss = nn.MSELoss()

pixelwise_loss.to(device=device)

# Initialize weights
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

Tensor = torch.cuda.FloatTensor if device else torch.FloatTensor


In [13]:
def pullaway_loss(embeddings):
    # 각 임베딩의 노름(norm) 계산
    norm = torch.sqrt(torch.sum(embeddings ** 2, -1, keepdim=True))
    # 임베딩 정규화
    normalized_emb = embeddings / norm
    # 정규화된 임베딩 간의 유사도 계산
    similarity = torch.matmul(normalized_emb, normalized_emb.transpose(1, 0))
    batch_size = embeddings.size(0)
    # 유사도 행렬에서 대각선을 제외한 모든 요소의 합을 계산하고, 이를 통해 pull away loss 계산
    loss_pt = (torch.sum(similarity) - batch_size) / (batch_size * (batch_size - 1))
    return loss_pt

In [14]:
lambda_pt = 0.1
margin = max(1, opt.batch_size / 64.0)

def training_step(engine, data):
    generator.train()
    discriminator.train()

    real_imgs = data[0].to(device)
    b_size = real_imgs.size(0)

    optimizer_G.zero_grad()
    z = torch.randn(b_size, opt.latent_dim, 1, 1, device=device)
    gen_imgs = generator(z)
    recon_imgs, img_embeddings = discriminator(gen_imgs)

    g_loss = pixelwise_loss(recon_imgs, gen_imgs.detach()) + lambda_pt * pullaway_loss(img_embeddings)
    g_loss.backward()
    optimizer_G.step()


    optimizer_D.zero_grad()
    real_recon, _ = discriminator(real_imgs)
    fake_recon, _ = discriminator(gen_imgs.detach())
    
    d_loss_real = pixelwise_loss(real_recon, real_imgs)
    d_loss_fake = pixelwise_loss(fake_recon, gen_imgs.detach())

    d_loss = d_loss_real
    if (margin - d_loss_fake.data).item() > 0:
        d_loss += margin - d_loss_fake

    d_loss.backward()
    optimizer_D.step()


    return {
        "Loss_G" : g_loss.item(),
        "Loss_D" : d_loss.item(),
    }



In [15]:
trainer = Engine(training_step)

In [16]:
@trainer.on(Events.STARTED)
def init_weights():
    discriminator.apply(weights_init_normal)
    generator.apply(weights_init_normal)

In [17]:
G_losses = []
D_losses = []

@trainer.on(Events.ITERATION_COMPLETED)
def store_losses(engine):
    o = engine.state.output
    G_losses.append(o["Loss_G"])
    D_losses.append(o["Loss_D"])

In [18]:
fixed_noise = torch.randn(10, opt.latent_dim, 1, 1, device=device)

In [19]:
img_list = []

@trainer.on(Events.ITERATION_COMPLETED(every=500))
def store_images(engine):
    with torch.no_grad():
        fake = generator(fixed_noise).cpu()
    img_list.append(fake)

In [20]:
import PIL.Image as Image


def interpolate(batch):
    arr = []
    for img in batch:
        pil_img = transforms.ToPILImage()(img)
        arr.append(transforms.ToTensor()(pil_img))
    return torch.stack(arr)


def evaluation_step(engine, batch):
    with torch.no_grad():
        noise = torch.randn(opt.batch_size, opt.latent_dim, 1, 1, device=device)
        generator.eval().to(device=device)
        fake_batch = generator(noise)
        fake = interpolate(fake_batch)
        # print(fake.shape)
        real = interpolate(batch[0])
        # print(real.shape)
        return fake, real

In [21]:
evaluator = Engine(evaluation_step)

In [22]:
from ignite.metrics import FID, InceptionScore
fid_metric = FID(device=device)
is_metric = InceptionScore(device=device, output_transform=lambda x: x[0])

In [23]:
evaluator = Engine(evaluation_step)
fid_metric.attach(evaluator, "fid")
is_metric.attach(evaluator, "is")

In [24]:
fid_values = []
is_values = []


@trainer.on(Events.EPOCH_COMPLETED)
def log_training_results(engine):
    evaluator.run(test_dataloader,max_epochs=1)
    metrics = evaluator.state.metrics
    fid_score = metrics['fid']
    is_score = metrics['is']
    fid_values.append(fid_score)
    is_values.append(is_score)
    print(f"Epoch [{engine.state.epoch}/{opt.n_epochs}] Metric Scores")
    print(f"*   FID : {fid_score:4f}")
    print(f"*    IS : {is_score:4f}")

    
    with torch.no_grad():
        fake = generator(fixed_noise).to(device)

    epoch_number = engine.state.epoch
    image_filename = f"img/FLIR_EBGAN/{epoch_number}_epoch.png"
    save_image(fake.data[:10], image_filename, nrow=5, normalize=True)


In [25]:
from ignite.metrics import RunningAverage

RunningAverage(output_transform=lambda x: x["Loss_G"]).attach(trainer, 'Loss_G')
RunningAverage(output_transform=lambda x: x["Loss_D"]).attach(trainer, 'Loss_D')

In [26]:
from ignite.contrib.handlers import ProgressBar

ProgressBar().attach(trainer, metric_names=['Loss_G','Loss_D'])
ProgressBar().attach(evaluator)

  from tqdm.autonotebook import tqdm


In [27]:
def training(*args):
    trainer.run(train_dataloader, max_epochs=opt.n_epochs)

In [28]:
real_batch = next(iter(train_dataloader))
print(real_batch[0].shape)

torch.Size([9, 3, 256, 256])


In [29]:
training()

  return F.mse_loss(input, target, reduction=self.reduction)
                                                                                    

Epoch [1/100] Metric Scores
*   FID : 0.464624
*    IS : 1.707409


  return F.mse_loss(input, target, reduction=self.reduction)
                                                                                    

Epoch [2/100] Metric Scores
*   FID : 0.960117
*    IS : 1.564816


                                                                                   

Epoch [3/100] Metric Scores
*   FID : 0.769778
*    IS : 1.887184


                                                                                   

Epoch [4/100] Metric Scores
*   FID : 0.823062
*    IS : 1.538081


                                                                                   

Epoch [5/100] Metric Scores
*   FID : 0.846349
*    IS : 1.529455


                                                                                    

Epoch [6/100] Metric Scores
*   FID : 0.705635
*    IS : 2.124219


                                                                                   

Epoch [7/100] Metric Scores
*   FID : 0.532145
*    IS : 2.016595


                                                                                   

Epoch [8/100] Metric Scores
*   FID : 0.749553
*    IS : 1.874181


                                                                                        

Epoch [9/100] Metric Scores
*   FID : 0.857552
*    IS : 1.322429


                                                                                     

Epoch [10/100] Metric Scores
*   FID : 0.850596
*    IS : 1.329325


                                                                                     

Epoch [11/100] Metric Scores
*   FID : 0.848965
*    IS : 1.307050


                                                                                     

Epoch [12/100] Metric Scores
*   FID : 0.765052
*    IS : 1.378920


                                                                                     

Epoch [13/100] Metric Scores
*   FID : 0.688591
*    IS : 1.385739


                                                                                     

Epoch [14/100] Metric Scores
*   FID : 0.859589
*    IS : 1.281952


                                                                                     

Epoch [15/100] Metric Scores
*   FID : 0.915492
*    IS : 1.248523


                                                                                     

Epoch [16/100] Metric Scores
*   FID : 0.342982
*    IS : 1.865908


                                                                                     

Epoch [17/100] Metric Scores
*   FID : 0.329499
*    IS : 1.519030


                                                                                     

Epoch [18/100] Metric Scores
*   FID : 0.477110
*    IS : 1.762433


                                                                                     

Epoch [19/100] Metric Scores
*   FID : 0.337249
*    IS : 1.770704


                                                                                     

Epoch [20/100] Metric Scores
*   FID : 0.583976
*    IS : 1.978966


                                                                                     

Epoch [21/100] Metric Scores
*   FID : 0.504738
*    IS : 1.977507


                                                                                     

Epoch [22/100] Metric Scores
*   FID : 0.503136
*    IS : 1.913811


                                                                                     

Epoch [23/100] Metric Scores
*   FID : 0.387531
*    IS : 2.175328


                                                                                     

Epoch [24/100] Metric Scores
*   FID : 0.552491
*    IS : 1.693135


Epoch [25/180]: [574/1737]  33%|███▎      , Loss_G=3.72, Loss_D=0.0282 [00:28<00:53]

In [None]:
%matplotlib inline 

plt.figure(figsize=(10,5))
plt.title("Genera or and Discriminator Los s During Training")
plt.plot(G_losses,label="G")
plt.plot(D_losses,label="D")
plt.xlabel("iterations") 
plt.ylabel("Loss")
plt.legend()

plt.savefig('EBGAN_training_loss.png')
plt.show()

In [None]:
fig, ax1 = plt.subplots()

plt.title("Evaluation Metric During Training")

color = 'tab:red'
ax1.set_xlabel('epochs')
ax1.set_ylabel('IS', color=color)
ax1.plot(is_values, color=color)

ax2 = ax1.twinx()

color = 'tab:blue'
ax2.set_ylabel('FID', color=color)
ax2.plot(fid_values, color=color)

fig.tight_layout()

fig.savefig('EBGAN_evaluation_metric.png')
plt.show()

In [None]:
%matplotlib inline

# Grab a batch of real images from the dataloader
real_batch = next(iter(train_dataloader))

# Plot the real images
plt.figure(figsize=(30,30))
plt.subplot(1,2,1)
plt.axis("off")
plt.title("Real Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0][:9], padding=5, normalize=True).cpu(),(1,2,0)))

# Plot the fake images from the last epoch
plt.subplot(1,2,2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(vutils.make_grid(img_list[-1], padding=2, normalize=True).cpu(),(1,2,0)))

plt.savefig('EBGAN_real_fake_images_comparison.png')
plt.show()