In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt


In [None]:
# Load CIFAR-10. ToTensor maps pixels to [0, 1]; Normalize with a per-channel
# mean and std of 0.5 then rescales them to [-1, 1].
_channel_stats = (0.5, 0.5, 0.5)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(_channel_stats, _channel_stats),
])

# Build the train split first, then the test split (same order as a plain
# two-statement version, so download/progress output is unchanged).
trainset, testset = (
    datasets.CIFAR10(root='./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# Only the training loader reshuffles each epoch.
trainloader, testloader = (
    DataLoader(split, batch_size=64, shuffle=reshuffle)
    for split, reshuffle in ((trainset, True), (testset, False))
)

# Corrupt images with additive Gaussian noise.
def add_noise(images, noise_factor=0.5, low=-1.0, high=1.0):
    """Return a copy of ``images`` corrupted with additive Gaussian noise.

    Args:
        images: Image tensor. The CIFAR-10 tensors in this notebook are
            normalized with ``Normalize(0.5, 0.5)``, i.e. they lie in [-1, 1].
        noise_factor: Standard deviation of the zero-mean Gaussian noise.
        low: Lower clamp bound for the noisy result (default -1.0).
        high: Upper clamp bound for the noisy result (default 1.0).
            The [-1, 1] default matches the normalized data and the
            generator's Tanh output range; pass ``low=0.0`` for tensors
            that are still in [0, 1].

    Returns:
        A new tensor with the same shape, dtype and device as ``images``.
    """
    # randn_like follows the dtype/device of `images`, so no explicit transfer is needed.
    noisy_images = images + noise_factor * torch.randn_like(images)
    # Clamping to [0, 1] would erase the entire negative half of normalized
    # data; keep values inside the data's actual range instead.
    return torch.clamp(noisy_images, low, high)

# Grab one training batch and corrupt it with noise.
data_iter = iter(trainloader)
# DataLoader iterators no longer expose a `.next()` method in current PyTorch;
# the built-in next() works on every version.
images, labels = next(data_iter)
noisy_images = add_noise(images)


def _show_row(batch, title, count=5):
    """Plot the first ``count`` images of ``batch`` side by side, each titled ``title``.

    Each image is a normalized channels-first tensor in [-1, 1]; it is mapped
    back to [0, 1] and moved channels-last, which is what ``imshow`` expects
    for float RGB data (values outside [0, 1] would otherwise be clipped).
    """
    # squeeze=False keeps `axes` 2-D even when count == 1.
    fig, axes = plt.subplots(1, count, figsize=(10, 5), squeeze=False)
    for ax, img in zip(axes[0], batch[:count]):
        ax.imshow(np.transpose((img * 0.5 + 0.5).clamp(0, 1).numpy(), (1, 2, 0)))
        ax.set_title(title)
        ax.axis('off')
    plt.show()


# Compare the clean images with their noisy versions.
_show_row(images, "Original")
_show_row(noisy_images, "Noisy")


In [None]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1)
        self.relu = nn.ReLU(True)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1)
        self.deconv1 = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1)
        self.deconv2 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1)
        self.deconv3 = nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1)
        self.tanh = nn.Tanh()

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        x = self.relu(self.deconv1(x))
        x = self.relu(self.deconv2(x))
        x = self.tanh(self.deconv3(x))
        return x


In [None]:
class Discriminator(nn.Module):
    """Binary real/fake classifier for 3-channel images.

    Three stride-2 convolutions (kernel 4, padding 1) halve the spatial size
    each time. The flattened 256-channel feature map feeds a single linear
    unit followed by a sigmoid, giving one probability per image, shape
    (N, 1). ``fc`` is sized for a 4x4 final feature map, so inputs must be
    32x32 (CIFAR-10); other sizes fail at the linear layer.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1)
        self.fc = nn.Linear(256 * 4 * 4, 1)
        # Activation used by forward() after every conv layer. LeakyReLU(0.2)
        # keeps a gradient for negative activations, which is the usual choice
        # for DCGAN-style discriminators. It has no parameters, so state_dict
        # keys are unaffected.
        self.relu = nn.LeakyReLU(0.2, inplace=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return the probability, shape (N, 1), that each input image is real."""
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        x = x.view(x.size(0), -1)  # flatten to (N, 256 * 4 * 4)
        x = self.sigmoid(self.fc(x))
        return x


In [None]:
# Model setup: use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# Losses and optimizers (Adam with lr 2e-4, beta1 0.5).
criterion = nn.BCELoss()
# Pixel-wise L1 term tying each generated image to its clean target. With the
# adversarial loss alone, the generator is only rewarded for producing *some*
# realistic image, not a restoration of the specific noisy input it received.
reconstruction_criterion = nn.L1Loss()
L1_WEIGHT = 100  # weight of the L1 term relative to the adversarial term (pix2pix uses 100)
optimizer_g = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Training loop
epochs = 10
generator.train()
discriminator.train()
for epoch in range(epochs):
    for imgs, _ in trainloader:
        imgs = imgs.to(device)
        # randn_like inside add_noise follows `imgs`, so this is already on `device`.
        noisy_imgs = add_noise(imgs)

        # Targets: 1 for real images, 0 for generated ones.
        real_labels = torch.ones(imgs.size(0), 1, device=device)
        fake_labels = torch.zeros(imgs.size(0), 1, device=device)

        # -----------------
        # Discriminator step
        # -----------------
        optimizer_d.zero_grad()

        # Real images should be classified as real.
        real_outputs = discriminator(imgs)
        d_loss_real = criterion(real_outputs, real_labels)

        # Generated images should be classified as fake; detach() keeps this
        # step from backpropagating into the generator.
        fake_imgs = generator(noisy_imgs)
        fake_outputs = discriminator(fake_imgs.detach())
        d_loss_fake = criterion(fake_outputs, fake_labels)

        # Total discriminator loss
        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        optimizer_d.step()

        # -----------------
        # Generator step
        # -----------------
        optimizer_g.zero_grad()

        # Fool the (just-updated) discriminator while staying close to the clean image.
        g_loss_adv = criterion(discriminator(fake_imgs), real_labels)
        g_loss_l1 = reconstruction_criterion(fake_imgs, imgs)
        g_loss = g_loss_adv + L1_WEIGHT * g_loss_l1
        g_loss.backward()
        optimizer_g.step()

    # Losses reported are those of the last batch of the epoch.
    print(f'Epoch [{epoch+1}/{epochs}], D Loss: {d_loss.item()}, G Loss: {g_loss.item()}')

# Restore the noisy sample batch with the trained generator.
generator.eval()  # explicit inference mode
# Run on whatever device the generator lives on instead of assuming CUDA.
device = next(generator.parameters()).device
with torch.no_grad():
    noisy_img = noisy_images.to(device)
    restored_img = generator(noisy_img)

# Visualization: Tanh outputs lie in [-1, 1], so map them back to [0, 1]
# before imshow (which clips float RGB data outside [0, 1]).
fig, axes = plt.subplots(1, 5, figsize=(10, 5))
for i in range(5):
    img = (restored_img[i].cpu() * 0.5 + 0.5).clamp(0, 1)
    axes[i].imshow(np.transpose(img.numpy(), (1, 2, 0)))
    axes[i].set_title("Restored by GAN")
    axes[i].axis('off')

plt.show()
