In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision
import torchvision.transforms as transforms
from torchvision import datasets
import matplotlib.pyplot as plt
from tqdm import tqdm

# CIFAR-10 데이터셋 로드 및 전처리
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # CIFAR-10은 [0,1] 범위로 정규화
])

trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)
testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = DataLoader(testset, batch_size=64, shuffle=False, num_workers=2)

# Autoencoder 모델 구현
class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__()
        
        # 인코더
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
        )
        
        # 디코더
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 3, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid()  # [0, 1] 범위로 출력
        )
    
    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x


# 모델 학습
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Autoencoder().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 모델 학습
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for data in tqdm(trainloader):
        images, _ = data
        noisy_images = images + 0.1 * torch.randn_like(images)  # 노이즈 추가
        noisy_images = torch.clip(noisy_images, 0., 1.)

        # 데이터를 GPU로 이동
        images, noisy_images = images.to(device), noisy_images.to(device)
        
        # 모델 예측
        outputs = model(noisy_images)
        
        # 손실 계산
        loss = criterion(outputs, images)
        
        # 역전파
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(trainloader)}")


# CIFAR-10의 복원된 이미지를 새로운 데이터셋으로 저장하기 위해 Dataset 클래스를 확장
class RestoredCIFAR10(Dataset):
    def __init__(self, original_dataset, model, transform=None):
        self.original_dataset = original_dataset
        self.model = model
        self.transform = transform
    
    def __len__(self):
        return len(self.original_dataset)
    
    def __getitem__(self, idx):
        image, label = self.original_dataset[idx]
        
        # 모델을 통해 복원
        with torch.no_grad():
            restored_image = self.model(image.unsqueeze(0).to(device)).cpu().squeeze(0)
        
        if self.transform:
            restored_image = self.transform(restored_image)
            image = self.transform(image)
        
        return restored_image, label


# 새로운 데이터셋 생성
restored_trainset = RestoredCIFAR10(trainset, model, transform=transform)
restored_trainloader = DataLoader(restored_trainset, batch_size=64, shuffle=True)

restored_testset = RestoredCIFAR10(testset, model, transform=transform)
restored_testloader = DataLoader(restored_testset, batch_size=64, shuffle=False)




Files already downloaded and verified
Files already downloaded and verified


100%|██████████| 782/782 [01:00<00:00, 12.87it/s]


Epoch [1/10], Loss: 0.2583478016735953


100%|██████████| 782/782 [01:00<00:00, 13.03it/s]


Epoch [2/10], Loss: 0.25599657917571494


100%|██████████| 782/782 [00:59<00:00, 13.04it/s]


Epoch [3/10], Loss: 0.2559980327820839


100%|██████████| 782/782 [00:59<00:00, 13.11it/s]


Epoch [4/10], Loss: 0.25593785078400544


100%|██████████| 782/782 [00:59<00:00, 13.09it/s]


Epoch [5/10], Loss: 0.2559817931650545


100%|██████████| 782/782 [01:00<00:00, 13.02it/s]


Epoch [6/10], Loss: 0.25597226747390256


100%|██████████| 782/782 [00:59<00:00, 13.10it/s]


Epoch [7/10], Loss: 0.25598161496088634


100%|██████████| 782/782 [00:58<00:00, 13.40it/s]


Epoch [8/10], Loss: 0.25600550451394544


100%|██████████| 782/782 [00:57<00:00, 13.50it/s]


Epoch [9/10], Loss: 0.25597149318517626


100%|██████████| 782/782 [01:00<00:00, 12.99it/s]

Epoch [10/10], Loss: 0.25596903910493607





In [5]:
def show_images(original, restored):
    fig, ax = plt.subplots(1, 2, figsize=(12, 4))

    # 원본 이미지를 NumPy 배열로 변환 후 시각화
    ax[0].imshow(original.permute(1, 2, 0).cpu().detach().numpy())
    ax[0].set_title('Original Image')

    # 복원된 이미지를 NumPy 배열로 변환 후 시각화
    ax[1].imshow(restored.permute(1, 2, 0).cpu().detach().numpy())
    ax[1].set_title('Restored Image')

    plt.show()

# 첫 번째 배치의 이미지를 시각화
model.eval()
with torch.no_grad():
    data_iter = iter(restored_trainloader)
    original_images, _ = next(data_iter)
    
    # 첫 번째 배치에서 원본 이미지와 복원된 이미지를 비교
    restored_images = original_images  # 복원된 이미지가 이미 RestoredCIFAR10에서 반환됨
    show_images(original_images[0], restored_images[0])  # 첫 번째 이미지를 비교

TypeError: pic should be PIL Image or ndarray. Got <class 'torch.Tensor'>