In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt


In [None]:
# Load the CIFAR-10 dataset.
# Only ToTensor() is applied, so pixel values stay in [0, 1]. Do not add
# transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) here: it maps pixels to
# [-1, 1], while add_noise() clamps its output to [0, 1] and plt.imshow expects
# float RGB data in [0, 1]. Mixing the two ranges zeroes every negative pixel of
# the noisy input and trains the model to map [0, 1] inputs onto [-1, 1] targets.
SEED = 0
BATCH_SIZE = 64

torch.manual_seed(SEED)  # reproducible shuffling and noise

transform = transforms.Compose([
    transforms.ToTensor(),
])

trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

trainloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
testloader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False)

# Function that corrupts images with additive Gaussian noise.
def add_noise(images, noise_factor=0.5, clip_min=0.0, clip_max=1.0):
    """Return a noisy copy of ``images`` with additive Gaussian noise.

    Args:
        images: Tensor of pixel values. The default clip range assumes values
            in [0, 1] (e.g. the output of ``transforms.ToTensor()``).
        noise_factor: Scale (standard deviation) of the zero-mean Gaussian noise.
        clip_min: Lower bound the noisy result is clamped to.
        clip_max: Upper bound the noisy result is clamped to. Pass
            ``clip_min=-1.0, clip_max=1.0`` for data normalized to [-1, 1].

    Returns:
        A new tensor with the same shape as ``images``; ``images`` is not modified.
    """
    # randn_like creates the noise with the same shape, dtype and device as the input.
    noise = noise_factor * torch.randn_like(images)
    # Clamp so the corrupted result stays inside the valid pixel range.
    return torch.clamp(images + noise, clip_min, clip_max)

# Take one training batch and add noise to it for a quick visual check.
# DataLoader iterators no longer provide a `.next()` method; use the builtin next().
data_iter = iter(trainloader)
images, labels = next(data_iter)
noisy_images = add_noise(images)


def _show_row(batch, title, n=5):
    """Plot the first ``n`` images of ``batch`` side by side in one figure.

    Each item of ``batch`` is treated as channels-first (C, H, W) and transposed
    to (H, W, C) for matplotlib.
    """
    fig, axes = plt.subplots(1, n, figsize=(10, 5), squeeze=False)
    for ax, img in zip(axes[0], batch[:n]):
        ax.imshow(np.transpose(img.numpy(), (1, 2, 0)))
        ax.set_title(title)
        ax.axis('off')
    plt.show()


# Visualize the clean images and their corrupted versions.
_show_row(images, "Original")
_show_row(noisy_images, "Noisy")


In [None]:
class DnCNN(nn.Module):
    """Convolutional image denoiser in the style of DnCNN (Zhang et al., 2017).

    The network is a plain stack of 3x3 convolutions:
    - Conv + ReLU;
    - ``num_layers - 2`` hidden Conv + ReLU layers;
    - a final Conv back to ``num_channels``.

    stride=1 and padding=1 keep the spatial size unchanged.

    With ``residual=True`` (the default, as in DnCNN) the stack estimates the
    noise and ``forward`` returns ``x - estimated_noise``. With
    ``residual=False`` the stack's output is returned directly as the denoised
    image. Unlike the reference DnCNN, no batch normalization is used in the
    hidden layers.

    Args:
        num_channels: Number of channels of the input and output images.
        num_filters: Number of feature maps in each hidden layer.
        num_layers: Total number of convolution layers (must be >= 2).
        residual: Whether to use residual (noise-prediction) learning.

    Raises:
        ValueError: If ``num_layers < 2``.
    """

    def __init__(self, num_channels=3, num_filters=64, num_layers=17, residual=True):
        super().__init__()
        if num_layers < 2:
            raise ValueError(f"num_layers must be >= 2, got {num_layers}")
        self.residual = residual

        layers = [
            nn.Conv2d(num_channels, num_filters, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
        ]

        # Hidden Conv+ReLU layers (a plain stack, no per-layer skip connections).
        for _ in range(num_layers - 2):
            layers.append(nn.Conv2d(num_filters, num_filters, kernel_size=3, stride=1, padding=1))
            layers.append(nn.ReLU(inplace=True))

        # Final projection back to image channels, with no activation.
        layers.append(nn.Conv2d(num_filters, num_channels, kernel_size=3, stride=1, padding=1))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Denoise ``x``; the result has the same shape as ``x``."""
        out = self.model(x)
        # Residual learning: the stack predicts the noise, which is subtracted from the input.
        return x - out if self.residual else out


In [None]:
# Initialize the model on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = DnCNN().to(device)

# Loss function and optimizer.
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop: learn to map noisy images back to the clean originals.
epochs = 10
for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    for imgs, _ in trainloader:
        imgs = imgs.to(device)
        # add_noise draws its noise with randn_like, so the result is already on `device`.
        noisy_imgs = add_noise(imgs)

        # Forward pass.
        optimizer.zero_grad()
        denoised_imgs = model(noisy_imgs)

        # Loss against the clean images, then backpropagate and update.
        loss = criterion(denoised_imgs, imgs)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Mean per-batch loss over the epoch.
    print(f'Epoch [{epoch+1}/{epochs}], Loss: {running_loss/len(trainloader)}')


In [None]:
# Restore the noisy sample batch with the trained model.
model.eval()
# Run on whatever device the model lives on instead of assuming CUDA.
device = next(model.parameters()).device
with torch.no_grad():
    noisy_img = noisy_images.to(device)
    restored_img = model(noisy_img)

# Clamp to the displayable [0, 1] range and move to the CPU once for plotting.
restored_display = restored_img.clamp(0.0, 1.0).cpu()

# Visualize the restored images.
fig, axes = plt.subplots(1, 5, figsize=(10, 5))
for i in range(5):
    axes[i].imshow(np.transpose(restored_display[i].numpy(), (1, 2, 0)))
    axes[i].set_title("Restored by DNCNN")
    axes[i].axis('off')

plt.show()
