#오토인코더 예제

# In [1]:
#라이브러리 호출

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt

import torchvision.datasets as datasets
import torchvision.transforms as transforms

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


# In [2]:
# Download MNIST (if not already under MNIST_data/) and convert each image
# to a tensor with transforms.ToTensor().
transform = transforms.Compose([transforms.ToTensor()])

train_dataset = datasets.MNIST(
    root='MNIST_data/', train=True, transform=transform, download=True
)
test_dataset = datasets.MNIST(
    root='MNIST_data/', train=False, transform=transform, download=True
)

# Batches of 100; training batches are shuffled, evaluation order is fixed.
# drop_last=True discards any final partial batch.
train_loader = DataLoader(
    train_dataset, batch_size=100, shuffle=True, drop_last=True
)
test_loader = DataLoader(
    test_dataset, batch_size=100, shuffle=False, drop_last=True
)

# In [3]:
#네트워크 신경망 생성

class Encoder(nn.Module):
  """Convolutional encoder: single-channel image -> latent vector.

  Args:
    encoded_space_dim: size of the latent vector produced by ``forward``.
    fc2_input_dim: accepted for signature compatibility; not used.

  The conv stack must reduce the input to a 32x3x3 feature map, because the
  first Linear layer expects 3 * 3 * 32 inputs (a 28x28 input does this).
  """

  def __init__(self, encoded_space_dim, fc2_input_dim):
    super().__init__()
    # Image feature extractor built from strided convolutions
    # (spatial sizes shown for a 28x28 input).
    conv_layers = [
        nn.Conv2d(1, 8, 3, stride=2, padding=1),    # 28x28 -> 14x14
        nn.ReLU(True),
        nn.Conv2d(8, 16, 3, stride=2, padding=1),   # 14x14 -> 7x7
        nn.BatchNorm2d(16),
        nn.ReLU(True),
        nn.Conv2d(16, 32, 3, stride=2, padding=0),  # 7x7 -> 3x3
        nn.ReLU(True),
    ]
    self.encoder_cnn = nn.Sequential(*conv_layers)
    self.flatten = nn.Flatten(start_dim=1)
    self.encoder_lin = nn.Sequential(
        nn.Linear(32 * 3 * 3, 128),
        nn.ReLU(True),
        nn.Linear(128, encoded_space_dim),
    )

  def forward(self, x):
    features = self.encoder_cnn(x)
    return self.encoder_lin(self.flatten(features))

class Decoder(nn.Module):
  """Convolutional decoder mirroring ``Encoder``: latent vector -> 1x28x28 image.

  Args:
    encoded_space_dim: size of the latent vector fed to ``forward``.
    fc2_input_dim: accepted for signature symmetry with ``Encoder``; not used.

  ``forward`` passes the final feature map through a sigmoid, so outputs lie
  in [0, 1].
  """

  def __init__(self, encoded_space_dim, fc2_input_dim):
    super().__init__()
    hidden, channels, side = 128, 32, 3
    self.decoder_lin = nn.Sequential(
        nn.Linear(encoded_space_dim, hidden),
        nn.ReLU(True),
        nn.Linear(hidden, channels * side * side),
        nn.ReLU(True),
    )
    # Undo the encoder's Flatten: (batch, 288) -> (batch, 32, 3, 3).
    self.unflatten = nn.Unflatten(dim=1, unflattened_size=(channels, side, side))
    # Transposed convolutions mirroring the encoder's conv stack:
    # 3x3 -> 7x7 -> 14x14 -> 28x28.
    self.decoder_conv = nn.Sequential(
        nn.ConvTranspose2d(32, 16, 3, stride=2, output_padding=0),
        nn.BatchNorm2d(16),
        nn.ReLU(True),
        nn.ConvTranspose2d(16, 8, 3, stride=2, padding=1, output_padding=1),
        nn.BatchNorm2d(8),
        nn.ReLU(True),
        nn.ConvTranspose2d(8, 1, 3, stride=2, padding=1, output_padding=1),
    )

  def forward(self, x):
    return torch.sigmoid(self.decoder_conv(self.unflatten(self.decoder_lin(x))))


# In [4]:
#손실 함수와 옵티마이저 지정

# Build both halves of the autoencoder with a 4-dimensional latent space and
# move them to `device` (Module.to moves parameters in place and returns self).
encoder = Encoder(encoded_space_dim=4, fc2_input_dim=128).to(device)
decoder = Decoder(encoded_space_dim=4, fc2_input_dim=128).to(device)

# One parameter group per sub-network. Both groups currently use the same
# settings; keeping them separate lets each get its own hyperparameters later.
params_to_optimize = [{'params': net.parameters()} for net in (encoder, decoder)]

# NOTE(review): this rebinds `optim`, shadowing the `torch.optim` module that
# was imported as `optim` at the top of the file. The training cell passes
# this name as `optimizer=optim`, so renaming it requires changing both places.
optim = torch.optim.Adam(params_to_optimize, lr=0.001, weight_decay=1e-05)

loss_fn = torch.nn.MSELoss()

# In [5]:
import numpy as np

#모델 학습 함수 생성
def train_epoch(encoder, decoder, device, dataloader, loss_fn, optimizer, noise_factor=0.3):
  """Train the denoising autoencoder for one epoch.

  Each batch from *dataloader* (labels ignored: unsupervised) is corrupted with
  ``add_noise(batch, noise_factor)``, passed through encoder -> decoder, and the
  reconstruction is scored with ``loss_fn`` against the CLEAN batch.

  Args:
    encoder, decoder: the two halves of the autoencoder (switched to train mode).
    device: device that inputs are moved to before the forward pass.
    dataloader: iterable of ``(images, labels)`` pairs.
    loss_fn: reconstruction loss, called as ``loss_fn(prediction, target)``.
    optimizer: optimizer over the encoder and decoder parameters.
    noise_factor: scale of the Gaussian noise added to the inputs.

  Returns:
    The mean of the per-batch losses for this epoch (``np.mean``).
  """
  encoder.train()
  decoder.train()
  train_loss = []
  for image_batch, _ in dataloader:
    image_noisy = add_noise(image_batch, noise_factor).to(device)
    image_clean = image_batch.to(device)
    decoded_data = decoder(encoder(image_noisy))
    # The target must be the clean image. Comparing against image_noisy (as
    # before) trains the network to reproduce the noise instead of removing it.
    loss = loss_fn(decoded_data, image_clean)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    train_loss.append(loss.item())
  return np.mean(train_loss)


# In [6]:
#모델 테스트 함수 생성

def test_epoch(encoder, decoder, device, dataloader, loss_fn, noise_factor=0.3):
  """Evaluate denoising quality on *dataloader* without updating the model.

  As in ``train_epoch``, each batch is corrupted with
  ``add_noise(batch, noise_factor)`` and reconstructed by encoder -> decoder.
  The reconstructions and the clean batches from all iterations are
  concatenated on the CPU, and ``loss_fn`` is applied once over the whole set.

  Args:
    encoder, decoder: the two halves of the autoencoder (switched to eval mode).
    device: device that inputs are moved to before the forward pass.
    dataloader: iterable of ``(images, labels)`` pairs; labels are ignored.
    loss_fn: reconstruction loss, called as ``loss_fn(prediction, target)``.
    noise_factor: scale of the Gaussian noise added to the inputs.

  Returns:
    The value returned by ``loss_fn``, detached from autograd (``.data``).
  """
  encoder.eval()
  decoder.eval()
  conc_out = []  # per-batch reconstructions
  conc_label = []  # per-batch clean targets
  with torch.no_grad():
    for image_batch, _ in dataloader:
      # Previously noise_factor was ignored and the model was scored on clean
      # inputs, so the validation loss did not measure denoising.
      image_noisy = add_noise(image_batch, noise_factor).to(device)
      decoded_data = decoder(encoder(image_noisy))
      conc_out.append(decoded_data.cpu())
      conc_label.append(image_batch.cpu())
    # Merge all batches into one tensor each so the loss is computed once.
    conc_out = torch.cat(conc_out)
    conc_label = torch.cat(conc_label)
    val_loss = loss_fn(conc_out, conc_label)
  return val_loss.data

# In [7]:
#노이즈 데이터 생성

def add_noise(inputs, noise_factor=0.3):
  """Return a noisy copy of *inputs*; the input tensor itself is not modified.

  Standard-normal noise (``torch.randn_like``) is scaled by *noise_factor*,
  added element-wise, and the result is clamped to the range [0, 1].
  """
  gaussian = torch.randn_like(inputs)
  return torch.clamp(inputs + gaussian * noise_factor, min=0., max=1.)

# In [None]:
# Fix broken Korean (Hangul) glyphs in matplotlib by using Malgun Gothic (Windows).

import os
from matplotlib import font_manager

font_fname = 'C:/Windows/Fonts/malgun.ttf'
# Apply the font only when the file exists. On systems without this Windows
# path, FontProperties(fname=...).get_name() fails on the missing file and
# would abort the cell.
if os.path.exists(font_fname):
  font_family = font_manager.FontProperties(fname=font_fname).get_name()
  plt.rcParams["font.family"] = font_family

# In [8]:
import matplotlib.pyplot as plt
from matplotlib import rc

# Use NanumGothic so the Korean plot titles render (sets rcParams['font.family']).
# NOTE(review): the original comment said Colab ships this font by default;
# verify. If the family is not installed, matplotlib will not find it and
# Hangul may still render as empty boxes.
rc('font', family='NanumGothic')


# In [9]:
#이미지 시각화

def _show_panel(ax, image, title=None):
  """Draw an image tensor on *ax* in grayscale with both axes hidden."""
  # The tensor may be on the GPU: move it to the CPU and drop size-1 dims
  # before handing it to matplotlib. (The original called the misspelled
  # `.sqeeze()`, which raised AttributeError.)
  ax.imshow(image.cpu().squeeze().numpy(), cmap='gist_gray')
  ax.get_xaxis().set_visible(False)
  ax.get_yaxis().set_visible(False)
  if title is not None:
    ax.set_title(title)


def plot_ae_outputs(encoder, decoder, n=5, noise_factor=0.3):
  """Plot original, noisy, and reconstructed versions of the first *n* test images.

  Draws a 3 x n grid: row 1 shows ``test_dataset[i]``, row 2 the same image
  after ``add_noise(..., noise_factor)``, and row 3 the decoder(encoder(noisy))
  reconstruction. Each row's title is placed above the middle column.
  Both networks are switched to eval mode and left in it.
  """
  encoder.eval()
  decoder.eval()
  plt.figure(figsize=(10, 4.5))
  for i in range(n):
    img = test_dataset[i][0].unsqueeze(0)  # prepend a batch dimension
    image_noisy = add_noise(img, noise_factor).to(device)
    with torch.no_grad():
      rec_img = decoder(encoder(image_noisy))

    is_middle = i == n // 2
    _show_panel(plt.subplot(3, n, i + 1), img,
                '원래 이미지' if is_middle else None)
    _show_panel(plt.subplot(3, n, i + 1 + n), image_noisy,
                '노이즈가 적용되어 손상된 이미지' if is_middle else None)
    _show_panel(plt.subplot(3, n, i + 1 + 2 * n), rec_img,
                '재구성된 이미지' if is_middle else None)
  plt.subplots_adjust(left=0.1, bottom=0.1, right=0.7, top=0.9, wspace=0.3, hspace=0.3)
  plt.show()


# In [None]:
#모델 학습

import numpy as np

# Train for num_epochs epochs, recording train/validation loss per epoch in
# history_da, then visualize reconstructions from the trained model.
num_epochs = 30
noise_factor = 0.3  # same corruption level for training, validation and plotting
history_da = {'train_loss': [], 'val_loss': []}
loss_fn = torch.nn.MSELoss()

for epoch in range(num_epochs):
  print('EPOCH %d/%d' % (epoch + 1, num_epochs))
  train_loss = train_epoch(
      encoder=encoder,
      decoder=decoder,
      device=device,
      dataloader=train_loader,
      loss_fn=loss_fn,
      optimizer=optim,
      noise_factor=noise_factor)

  val_loss = test_epoch(
      encoder=encoder,
      decoder=decoder,
      device=device,
      dataloader=test_loader,
      loss_fn=loss_fn,
      noise_factor=noise_factor)

  history_da['train_loss'].append(train_loss)
  history_da['val_loss'].append(val_loss)

  print('\n EPOCH {}/{} \t train loss {:.3f} \t val loss {:.3f}'.format(epoch+1, num_epochs, train_loss, val_loss))

# The plotting helper is `plot_ae_outputs`; the previous call to
# `plot_ae_output` raised NameError after training finished.
plot_ae_outputs(encoder, decoder, noise_factor=noise_factor)
