Deep Convolutional GAN

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from PIL import Image
import time
import random
import glob
from google.colab import drive
from torchvision.datasets import ImageFolder
from google.colab import files
from IPython.display import display, HTML

In [None]:
def set_seed(seed):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random`` module, NumPy and PyTorch with the same value.
    When CUDA is available, all CUDA devices are seeded as well and cuDNN is
    switched to its deterministic mode.

    Args:
        seed: Integer seed passed to each generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
        
# Seed the RNGs, then use the GPU when one is available and the CPU otherwise.
set_seed(42)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"사용 장치: {device}")


In [None]:
# --- Data pipeline ---
batch_size = 64  # samples per DataLoader batch
image_size = 64  # images are resized and center-cropped to image_size x image_size

# --- Network widths ---
nz = 100  # channel count of the latent noise tensor fed to the generator
ngf = 64  # base number of feature maps in the generator
ndf = 64  # base number of feature maps in the discriminator

# --- Optimisation ---
num_epochs = 25  # increase the number of epochs if sharper images are wanted
lr = 2e-4  # learning rate given to both Adam optimizers
beta1 = 0.5  # first beta coefficient given to both Adam optimizers

In [None]:
def download_dog_dataset():
    """Download CIFAR-10 and export its dog images as an image-folder tree.

    The CIFAR-10 training split is downloaded to ``./cifar10``. Every sample
    whose label is the ``dog`` class is saved as
    ``./dog_dataset/dogs/dog_<n>.jpg``, where ``<n>`` counts the extracted
    dog images from zero.

    Returns:
        str: The dataset root directory, ``'./dog_dataset'``.
    """
    print("강아지 데이터셋 다운로드 중...")
    import torchvision.datasets as datasets

    cifar10 = datasets.CIFAR10(root='./cifar10', download=True, train=True)

    class_labels = cifar10.classes
    print(f"CIFAR-10 클래스 목록: {class_labels}")

    dog_idx = class_labels.index('dog')
    print(f"강아지 클래스 인덱스: {dog_idx}")

    # No transform is configured on the dataset, so each sample is a
    # (PIL image, label) pair; keep only the images labelled as dogs.
    dog_images = [
        img
        for img, label in (cifar10[i] for i in range(len(cifar10)))
        if label == dog_idx
    ]
    print(f"{len(dog_images)}개의 강아지 이미지를 추출했습니다.")

    # Create the output folder (a single class sub-directory named "dogs").
    os.makedirs('./dog_dataset/dogs', exist_ok=True)

    # Write the images to disk.
    for i, img in enumerate(dog_images):
        img.save(f'./dog_dataset/dogs/dog_{i}.jpg')
    print(f"이미지를 './dog_dataset/dogs/' 폴더에 저장했습니다.")
    return './dog_dataset'


In [None]:
# Reuse an already exported dataset when JPEG files are present on disk;
# otherwise download and export it.
try:
    if not os.path.exists('./dog_dataset') or not glob.glob('./dog_dataset/*/*.jpg'):
        data_root = download_dog_dataset()
    else:
        data_root = './dog_dataset'
        print(f"기존 데이터셋 사용: {data_root}")
except Exception:
    # Catch Exception instead of using a bare ``except`` so that
    # KeyboardInterrupt / SystemExit do not trigger a second download.
    print("데이터셋 확인 중 오류 발생 확인, 재다운로드를 시도합니다.")
    data_root = download_dog_dataset()

# Preprocessing applied to every image: resize and center-crop to
# image_size, convert to a tensor, then normalize each of the three
# channels with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])


In [None]:
# Build the training DataLoader from the image folder. If that fails, fall
# back to a synthetic dataset so the rest of the notebook can still run.
try:
    dataset = ImageFolder(root=data_root, transform=transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    print(f"데이터셋 로드 완료: {len(dataset)} 이미지")
except Exception as e:
    print(f"데이터셋 로드 오류: {e}")
    print("임의 데이터로 코드를 계속 실행합니다...")

    def create_random_dataset(num_samples=1000):
        """Return a list of ``(image, 0)`` pairs of random noise images.

        Each image is a ``(3, image_size, image_size)`` tensor drawn from a
        normal distribution, scaled by 0.2, shifted by 0.5 and clamped to
        ``[0, 1]``.
        """
        # NOTE(review): these values lie in [0, 1] and do not go through
        # `transform` (mean/std 0.5 normalization) — confirm this is intended.
        random_data = torch.randn(num_samples, 3, image_size, image_size)
        random_data = torch.clamp((random_data * 0.2) + 0.5, 0, 1)
        random_dataset = [(img, 0) for img in random_data]
        return random_dataset

    class RandomDataset(torch.utils.data.Dataset):
        """Minimal Dataset wrapper around an indexable sequence of samples."""

        def __init__(self, data):
            self.data = data

        def __len__(self):
            return len(self.data)

        def __getitem__(self, idx):
            return self.data[idx]

    random_dataset = RandomDataset(create_random_dataset())
    dataloader = DataLoader(random_dataset, batch_size=batch_size, shuffle=True)
    print(f"임의의 데이터셋 생성 완료: {len(random_dataset)} 이미지")


In [None]:
def weights_init(m):
    """Re-initialise one module in place, based on its class name.

    Intended for ``nn.Module.apply``. Modules whose class name contains
    ``Conv`` get weights drawn from N(0, 0.02). Modules whose class name
    contains ``BatchNorm`` get weights drawn from N(1, 0.02) and a zero bias.
    Any other module is left untouched.

    Args:
        m: The module to initialise.
    """
    layer_kind = type(m).__name__

    if 'Conv' in layer_kind:
        nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        return

    if 'BatchNorm' in layer_kind:
        nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        nn.init.constant_(m.bias.data, val=0)

In [None]:
class Generator(nn.Module):
    """DCGAN generator built from transposed convolutions.

    The network is a single ``nn.Sequential`` stored in ``self.main``:
    a first transposed convolution from ``nz`` to ``ngf * 8`` channels,
    three stride-2 stages that halve the channel count each time, and a final
    stride-2 transposed convolution to 3 channels followed by ``Tanh``.
    With noise of shape ``(N, nz, 1, 1)`` the output has shape
    ``(N, 3, 64, 64)``.
    """

    def __init__(self):
        super().__init__()

        # Entry stage: kernel 4, stride 1, no padding.
        stages = [
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
        ]

        # Upsampling stages: ngf*8 -> ngf*4 -> ngf*2 -> ngf channels.
        for factor in (8, 4, 2):
            wide, narrow = ngf * factor, ngf * (factor // 2)
            stages += [
                nn.ConvTranspose2d(wide, narrow, 4, 2, 1, bias=False),
                nn.BatchNorm2d(narrow),
                nn.ReLU(True),
            ]

        # Output stage: 3 channels squashed by Tanh.
        stages += [
            nn.ConvTranspose2d(ngf, 3, 4, 2, 1, bias=False),
            nn.Tanh(),
        ]

        self.main = nn.Sequential(*stages)

    def forward(self, input):
        """Run ``input`` through the sequential stack and return the result."""
        return self.main(input)


In [None]:
class Discriminator(nn.Module):
    """DCGAN discriminator built from strided convolutions.

    The network is a single ``nn.Sequential`` stored in ``self.main``:
    a first stride-2 convolution from 3 to ``ndf`` channels, three stride-2
    stages that double the channel count each time (with batch norm), and a
    final convolution to one channel followed by ``Sigmoid``.
    """

    def __init__(self):
        super().__init__()

        # Entry stage: no batch norm on the first convolution.
        layers = [
            nn.Conv2d(3, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        # Downsampling stages: ndf -> ndf*2 -> ndf*4 -> ndf*8 channels.
        for factor in (1, 2, 4):
            narrow, wide = ndf * factor, ndf * factor * 2
            layers += [
                nn.Conv2d(narrow, wide, 4, 2, 1, bias=False),
                nn.BatchNorm2d(wide),
                nn.LeakyReLU(0.2, inplace=True),
            ]

        # Output stage: one channel squashed to (0, 1) by Sigmoid.
        layers += [
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        ]

        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Return the network output flattened to a 1-D tensor."""
        scores = self.main(input)
        column = scores.view(-1, 1)
        return column.squeeze(1)


In [None]:
# Create both networks on the selected device. Both are constructed before
# either is re-initialised, so the order of random draws stays as written.
netG = Generator().to(device)
netD = Discriminator().to(device)

# Apply the custom Conv / BatchNorm initialisation to every sub-module.
netG.apply(weights_init)
netD.apply(weights_init)

# Show the layer layout of each network.
print("생성자 모델 구조:", netG, sep="\n")
print("\n판별자 모델 구조:", netD, sep="\n")


In [None]:
# Target values written into the label tensor during training.
real_label, fake_label = 1, 0

# Binary cross-entropy between discriminator outputs and those targets.
criterion = nn.BCELoss()

# A fixed batch of 64 noise vectors, reused after every epoch so the saved
# sample grids can be compared with one another.
fixed_noise = torch.randn(64, nz, 1, 1, device=device)

# Both networks use Adam with the same learning rate and betas.
adam_betas = (beta1, 0.999)
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=adam_betas)
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=adam_betas)


In [None]:
def visualize_images(images, title=None, display_in_notebook=True):
    """Draw a batch of images as one grid on a new 8x8-inch figure.

    Args:
        images: Batch of image tensors. Values are mapped with
            ``(images + 1) / 2`` before plotting.
        title: Optional figure title; skipped when falsy.
        display_in_notebook: When true, ``plt.show()`` is called.

    Returns:
        The grid as a NumPy array with the channel axis moved last.
    """
    rescaled = (images + 1) / 2.0

    grid = torchvision.utils.make_grid(rescaled, padding=2, normalize=False)
    channels_last = grid.permute(1, 2, 0)
    img = channels_last.cpu().numpy()

    plt.figure(figsize=(8, 8))
    plt.imshow(img)
    plt.axis('off')
    if title:
        plt.title(title)

    if display_in_notebook:
        plt.show()

    return img

# Output folders: image files go to results/, model snapshots to checkpoints/.
os.makedirs("results", exist_ok=True)
os.makedirs("checkpoints", exist_ok=True)

# Training history: per-iteration losses and per-epoch generated samples.
G_losses, D_losses, img_list = [], [], []


In [None]:
print("학습 시작...")


def progress_bar(current, total, bar_length=50):
    """Render a fixed-width text progress bar.

    Args:
        current: Number of completed steps.
        total: Total number of steps. A non-positive total is rendered as a
            full bar instead of raising ``ZeroDivisionError``.
        bar_length: Number of characters between the brackets.

    Returns:
        str: ``"[====    ] NN%"`` with exactly ``bar_length`` characters
        between the brackets.
    """
    if total <= 0:
        fraction = 1.0
    else:
        # Clamp so an out-of-range `current` cannot overflow the bar width
        # or produce a negative percentage.
        fraction = min(max(current / total, 0.0), 1.0)

    filled = int(fraction * bar_length)
    arrow = '=' * filled
    padding = ' ' * (bar_length - filled)
    return f"[{arrow}{padding}] {int(fraction * 100)}%"
# Adversarial training. For every batch the discriminator is updated first
# (on a real batch, then on a generated batch), followed by one generator
# update. After each epoch a sample grid is saved, and a checkpoint is
# written every 5 epochs and after the final epoch.
for epoch in range(num_epochs):
    start_time = time.time()
    for i, data in enumerate(dataloader, 0):
        # ---- Discriminator update, real batch ----
        netD.zero_grad()
        # Batches are (images, labels) pairs; only the images are used.
        real_cpu = data[0].to(device)

        # A loop-local name keeps the module-level `batch_size`
        # hyper-parameter from being overwritten by the current batch size.
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)

        output = netD(real_cpu)
        errD_real = criterion(output, label)
        errD_real.backward()
        D_x = output.mean().item()

        # ---- Discriminator update, generated batch ----
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        fake = netG(noise)
        label.fill_(fake_label)

        # detach() keeps this backward pass from reaching the generator.
        output = netD(fake.detach())
        errD_fake = criterion(output, label)
        errD_fake.backward()
        D_G_z1 = output.mean().item()

        errD = errD_real + errD_fake
        optimizerD.step()

        # ---- Generator update ----
        netG.zero_grad()
        # The generator is scored against the "real" target.
        label.fill_(real_label)

        output = netD(fake)
        errG = criterion(output, label)
        errG.backward()
        D_G_z2 = output.mean().item()

        optimizerG.step()

        G_losses.append(errG.item())
        D_losses.append(errD.item())

        if i % 10 == 0:
            prog = progress_bar(i, len(dataloader))
            print(f'\r에폭 [{epoch+1}/{num_epochs}] 배치 {prog} '
                  f'Loss_D: {errD.item():.4f} Loss_G: {errG.item():.4f} '
                  f'D(x): {D_x:.4f} D(G(z)): {D_G_z1:.4f}/{D_G_z2:.4f}', end='')

    # Finish the carriage-return progress line.
    print()

    # Generate samples from the fixed noise so epochs can be compared.
    with torch.no_grad():
        fake = netG(fixed_noise).detach().cpu()
    img_list.append(fake)

    # Save before showing: once plt.show() has run, the figure may already be
    # closed, so a later savefig would write an empty image.
    img = visualize_images(fake, title=f'에폭 {epoch+1} 생성 결과', display_in_notebook=False)
    plt.savefig(f'results/fake_dogs_epoch_{epoch+1}.png')
    plt.show()
    plt.close()

    if (epoch + 1) % 5 == 0 or (epoch + 1) == num_epochs:
        torch.save({
            'generator': netG.state_dict(),
            'discriminator': netD.state_dict(),
            'optimizerG': optimizerG.state_dict(),
            'optimizerD': optimizerD.state_dict(),
            'epoch': epoch,
            'G_losses': G_losses,
            'D_losses': D_losses,
        }, f'checkpoints/gan_model_epoch_{epoch+1}.pth')
        print(f"모델 체크포인트 저장: 에폭 {epoch+1}")
    elapsed = time.time() - start_time
    print(f'에폭 {epoch+1} 완료, 소요 시간: {elapsed:.2f}초')
print("학습 완료!")


In [None]:
# Plot the per-iteration generator and discriminator losses on one figure,
# save it to results/loss_plot.png, then display and close it.
plt.figure(figsize=(10, 5))
for series, legend_name in ((G_losses, "생성자"), (D_losses, "판별자")):
    plt.plot(series, label=legend_name)
plt.title("생성자와 판별자의 손실")
plt.xlabel("반복")
plt.ylabel("손실")
plt.legend()
plt.savefig('results/loss_plot.png')
plt.show()
plt.close()

# One panel per stored epoch, each showing the first 16 fixed-noise samples.
# The grid is sized from the number of panels actually drawn, and the figure
# is created together with the plotting code so both use the same figure.
n_panels = min(num_epochs, len(img_list))
rows = max(1, int(np.sqrt(n_panels)))  # at least 1 so the division below is safe
cols = max(1, int(np.ceil(n_panels / rows)))

plt.figure(figsize=(12, 12))
for i, samples in enumerate(img_list[:n_panels]):
    plt.subplot(rows, cols, i + 1)
    plt.axis('off')
    plt.title(f'에폭 {i+1}')

    img = torchvision.utils.make_grid(samples[:16], padding=2, normalize=True)
    plt.imshow(np.transpose(img.cpu().numpy(), (1, 2, 0)))

plt.tight_layout()
plt.savefig('results/progress.png')
plt.show()


In [None]:
def generate_new_dogs(num_images=16):
    """Sample new images from the trained generator and save them as a grid.

    Puts ``netG`` into eval mode (it is not switched back), draws
    ``num_images`` noise vectors, and writes the resulting image grid to
    ``results/final_generated_dogs.png`` before displaying it.

    Args:
        num_images: Number of images to generate.

    Returns:
        The generated images as a CPU tensor.
    """
    netG.eval()

    with torch.no_grad():
        noise = torch.randn(num_images, nz, 1, 1, device=device)
        fake_dogs = netG(noise).detach().cpu()

    # Save before showing: once plt.show() has run, the figure may already be
    # closed, so a later savefig would write an empty image.
    visualize_images(fake_dogs, title=f'생성된 강아지 이미지 {num_images}개', display_in_notebook=False)
    plt.savefig('results/final_generated_dogs.png')
    plt.show()

    print(f"{num_images}개의 새로운 강아지 이미지를 생성했습니다. ('results/final_generated_dogs.png'에 저장)")

    return fake_dogs


print("\n최종 모델로 새 이미지 생성:")
generate_new_dogs(16)


In [None]:
def download_results():
    """Print instructions for downloading the generated result files.

    Nothing is downloaded here; the function only prints the manual steps
    and the ``files.download(...)`` calls a user can run.
    """
    instructions = (
        "\n생성된 결과 다운로드:",
        "1. 왼쪽 사이드바에서 파일 탭(📁)을 클릭합니다.",
        "2. 'results' 폴더에서 생성된 이미지를 다운로드합니다.",
        "또는 아래 코드를 실행하여 결과 파일을 다운로드할 수 있습니다:",
        "files.download('results/final_generated_dogs.png')",
        "files.download('results/progress.png')",
        "files.download('results/loss_plot.png')",
    )
    for line in instructions:
        print(line)
# Final summary, followed by the download instructions.
# (A stray trailing backtick after the call was removed; it was a syntax error.)
print("\n=== GAN 실습 완료 ===")
print("생성된 모든 결과는 'results' 폴더에 저장되었습니다.")
download_results()
